diff --git a/build_vocab.py b/build_vocab.py index 577c2c1c..fabea1b2 100644 --- a/build_vocab.py +++ b/build_vocab.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from onmt.bin.build_vocab import main +from mammoth.bin.build_vocab import main if __name__ == "__main__": diff --git a/docs/source/CONTRIBUTING.md b/docs/source/CONTRIBUTING.md index 7ad1425b..717a3ca0 100644 --- a/docs/source/CONTRIBUTING.md +++ b/docs/source/CONTRIBUTING.md @@ -5,7 +5,7 @@ OpenNMT-py is a community developed project and we love developer contributions. ## Guidelines Before sending a PR, please do this checklist first: -- Please run `onmt/tests/pull_request_chk.sh` and fix any errors. When adding new functionality, also add tests to this script. Included checks: +- Please run `mammoth/tests/pull_request_chk.sh` and fix any errors. When adding new functionality, also add tests to this script. Included checks: 1. flake8 check for coding style; 2. unittest; 3. continuous integration tests listed in `.travis.yml`. diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md deleted file mode 100644 index 66a0b6b0..00000000 --- a/docs/source/FAQ.md +++ /dev/null @@ -1,37 +0,0 @@ -# Questions - -## What is the intuition behind fixed-length memory bank? -Specifically, for `lin` , the intuition behind the structured attention is to replace pooling over the hidden representations with multi-hop attentive representations (fixed length). What is the benefit for transforming source sequence representations into a fixed length memory bank? - -Push the model to be more language-agnostic. Sentence length tends to be language dependent. For example, French tends to produce longer sentences than English. - -Does the attention in attention bridge act as an enhancement of encoder? Will the attention bridge bring any benefits to decoders? -1. If we view attention bridge as a part of encoder, will the overall model be a partially shared encoder (separate lower layers and shared attention bridge) + separate decoders? - -If the shared attention is viewed is a part of encoder for many2one translation and a part of decoder for one2many translation, the shared attention module encoder some language-independent information to enhance encoding or decoding? - -## Models are saved with encoder, decoder, and generator. What is generator? -The generator contains Linear + activation (softmax or sparsesoftmax). - -### Why we need to separately save “generator”? -It seems unnecessary to separate the generator. Activation functions do not contain trainable parameters. - - -## What is the difference between `intermediate_output` and `encoder_output`? [🔗](./onmt/attention_bridge.py#L91) - -`intermediate_output` is the intermediate output of stacked n-layered attention bridges. `encoder_output` is literally the output of encoder, which was reused in the n-layered `PerceiverAttentionBridgeLayer`. - -For `PerceiverAttentionBridgeLayer` where the encoder output is projected into fixed length via `lattent_array`. But why? - -For `PerceiverAttentionBridgeLayer` : - -`intermediate_output` and `encoder_output` are used as: - -```python - S, B, F = encoder_output.shape - if intermediate_output is not None: - cross_query = intermediate_output - else: - cross_query = self.latent_array.unsqueeze(0).expand(B, -1, -1) - encoder_output = encoder_output.transpose(0, 1) -``` \ No newline at end of file diff --git a/docs/source/attention_bridges.md b/docs/source/attention_bridges.md index 0080a85f..3b014dbd 100644 --- a/docs/source/attention_bridges.md +++ b/docs/source/attention_bridges.md @@ -1,7 +1,7 @@ # Attention Bridge -The embeddings are generated through the self-attention mechanism ([Attention Bridge](./onmt/attention_bridge.py)) of the encoder and establish a connection with language-specific decoders that focus their attention on these embeddings. This is why they are referred to as 'bridges'. This architectural element serves to link the encoded information with the decoding process, enhancing the flow of information between different stages of language processing. +The embeddings are generated through the self-attention mechanism ([Attention Bridge](./mammoth/modules/attention_bridge.py)) of the encoder and establish a connection with language-specific decoders that focus their attention on these embeddings. This is why they are referred to as 'bridges'. This architectural element serves to link the encoded information with the decoding process, enhancing the flow of information between different stages of language processing. There are five types of attention mechanism implemented: @@ -61,7 +61,7 @@ The `PerceiverAttentionBridgeLayer` involves a multi-headed dot product self-att 3. **Linear Layer**: After normalization, the data is fed into a linear layer. This linear transformation can be seen as a learned projection of the attention-weighted data into a new space. -4. **ReLU Activation**: The output of the linear layer undergoes the Rectified Linear Unit (ReLU) activation function. +4. **ReLU Activation**: The output of the linear layer undergoes the Rectified Linear Unit (ReLU) activation function. 5. **Linear Layer (Second)**: Another linear layer is applied to the ReLU-activated output. @@ -72,11 +72,11 @@ The `PerceiverAttentionBridgeLayer` involves a multi-headed dot product self-att The process described involves dot product self-attention. The steps are as follows: 1. **Input Transformation**: Given an input matrix $\mathbf{H} \in \mathbb{R}^{d_h \times n}$, two sets of learned weight matrices are used to transform the input. These weight matrices are $\mathbf{W}_1 \in \mathbb{R}^{d_h \times d_a}$ and $\mathbf{W}_2 \in \mathbb{R}^{d_h \times d_a}$. The multiplication of $\mathbf{H}$ with $\mathbf{W}_1$ and $\mathbf{W}_2$ produces matrices $\mathbf{V}$ and $\mathbf{K}$, respectively: - + - $\mathbf{V} = \mathbf{H} \mathbf{W}_1$ - $\mathbf{K} = \mathbf{H} \mathbf{W}_2$ -2. **Attention Calculation**: The core attention calculation involves three matrices: $\mathbf{Q} \in \mathbb{R}^{d_h \times n}$, $\mathbf{K}$ (calculated previously), and $\mathbf{V}$ (calculated previously). The dot product of $\mathbf{Q}$ and $\mathbf{K}^\top$ is divided by the square root of the dimensionality of the input features ($\sqrt{d_h}$). +2. **Attention Calculation**: The core attention calculation involves three matrices: $\mathbf{Q} \in \mathbb{R}^{d_h \times n}$, $\mathbf{K}$ (calculated previously), and $\mathbf{V}$ (calculated previously). The dot product of $\mathbf{Q}$ and $\mathbf{K}^\top$ is divided by the square root of the dimensionality of the input features ($\sqrt{d_h}$). The final attended output is calculated by multiplying the attention weights with the $\mathbf{V}$ matrix: $\mathbf{H}^\prime = \operatorname{Softmax}(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_h}})\mathbf{V}$ @@ -86,5 +86,4 @@ The TransformerEncoderLayer employs multi-headed dot product self-attention (by ## FeedForwardAttentionBridgeLayer -The `FeedForwardAttentionBridgeLayer` module applies a sequence of linear transformations and `ReLU` activations to the input data, followed by an attention bridge normalization, enhancing the connectivity between different parts of the model. - +The `FeedForwardAttentionBridgeLayer` module applies a sequence of linear transformations and `ReLU` activations to the input data, followed by an attention bridge normalization, enhancing the connectivity between different parts of the model. diff --git a/docs/source/index.rst b/docs/source/index.rst index abb4c23b..7f04a9db 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,8 +38,8 @@ Contents :caption: API :maxdepth: 2 - onmt.rst - onmt.modules.rst - onmt.translation.rst - onmt.translate.translation_server.rst - onmt.inputters.rst + mammoth.rst + mammoth.modules.rst + mammoth.translation.rst + mammoth.translate.translation_server.rst + mammoth.inputters.rst diff --git a/docs/source/mammoth.inputters.rst b/docs/source/mammoth.inputters.rst new file mode 100644 index 00000000..b95aae67 --- /dev/null +++ b/docs/source/mammoth.inputters.rst @@ -0,0 +1,20 @@ +Data Loaders +================= + +Data Readers +------------- + +.. autoexception:: mammoth.inputters.datareader_base.MissingDependencyException + +.. autoclass:: mammoth.inputters.DataReaderBase + :members: + +.. autoclass:: mammoth.inputters.TextDataReader + :members: + + +Dataset +-------- + +.. autoclass:: mammoth.inputters.Dataset + :members: diff --git a/docs/source/mammoth.modules.rst b/docs/source/mammoth.modules.rst new file mode 100644 index 00000000..de33bfd5 --- /dev/null +++ b/docs/source/mammoth.modules.rst @@ -0,0 +1,109 @@ +Modules +============= + +Core Modules +------------ + +.. autoclass:: mammoth.modules.Embeddings + :members: + + +Encoders +--------- + +.. autoclass:: mammoth.encoders.EncoderBase + :members: + +.. autoclass:: mammoth.encoders.MeanEncoder + :members: + +.. autoclass:: mammoth.encoders.RNNEncoder + :members: + + +Decoders +--------- + + +.. autoclass:: mammoth.decoders.DecoderBase + :members: + +.. autoclass:: mammoth.decoders.decoder.RNNDecoderBase + :members: + +.. autoclass:: mammoth.decoders.StdRNNDecoder + :members: + +.. autoclass:: mammoth.decoders.InputFeedRNNDecoder + :members: + +Attention +---------- + +.. autoclass:: mammoth.modules.AverageAttention + :members: + +.. autoclass:: mammoth.modules.GlobalAttention + :members: + + + +Architecture: Transformer +---------------------------- + +.. autoclass:: mammoth.modules.PositionalEncoding + :members: + +.. autoclass:: mammoth.modules.position_ffn.PositionwiseFeedForward + :members: + +.. autoclass:: mammoth.encoders.TransformerEncoder + :members: + +.. autoclass:: mammoth.decoders.TransformerDecoder + :members: + +.. autoclass:: mammoth.modules.MultiHeadedAttention + :members: + :undoc-members: + + +Architecture: Conv2Conv +---------------------------- + +(These methods are from a user contribution +and have not been thoroughly tested.) + + +.. autoclass:: mammoth.encoders.CNNEncoder + :members: + + +.. autoclass:: mammoth.decoders.CNNDecoder + :members: + +.. autoclass:: mammoth.modules.ConvMultiStepAttention + :members: + +.. autoclass:: mammoth.modules.WeightNormConv2d + :members: + +Architecture: SRU +---------------------------- + +.. autoclass:: mammoth.models.sru.SRU + :members: + + +Copy Attention +-------------- + +.. autoclass:: mammoth.modules.CopyGenerator + :members: + + +Structured Attention +------------------------------------------- + +.. autoclass:: mammoth.modules.structured_attention.MatrixTree + :members: diff --git a/docs/source/mammoth.rst b/docs/source/mammoth.rst new file mode 100644 index 00000000..cd3d2a8f --- /dev/null +++ b/docs/source/mammoth.rst @@ -0,0 +1,32 @@ +Framework +================= + +Model +----- + +.. autoclass:: mammoth.models.NMTModel + :members: + +Trainer +------- + +.. autoclass:: mammoth.Trainer + :members: + + +.. autoclass:: mammoth.utils.Statistics + :members: + +Loss +---- + + +.. autoclass:: mammoth.utils.loss.LossComputeBase + :members: + + +Optimizer +--------- + +.. autoclass:: mammoth.utils.Optimizer + :members: diff --git a/docs/source/mammoth.translate.translation_server.rst b/docs/source/mammoth.translate.translation_server.rst new file mode 100644 index 00000000..0bc9dad7 --- /dev/null +++ b/docs/source/mammoth.translate.translation_server.rst @@ -0,0 +1,21 @@ +Server +====== + + +Models +------------- + +.. autoclass:: mammoth.translate.translation_server.ServerModel + :members: + + +Core Server +------------ + +.. autoexception:: mammoth.translate.translation_server.ServerModelError + +.. autoclass:: mammoth.translate.translation_server.Timer + :members: + +.. autoclass:: mammoth.translate.translation_server.TranslationServer + :members: diff --git a/docs/source/mammoth.translation.rst b/docs/source/mammoth.translation.rst new file mode 100644 index 00000000..6b075f96 --- /dev/null +++ b/docs/source/mammoth.translation.rst @@ -0,0 +1,39 @@ +Translation +================== + +Translations +------------- + +.. autoclass:: mammoth.translate.Translation + :members: + +Translator Class +----------------- + +.. autoclass:: mammoth.translate.Translator + :members: + +.. autoclass:: mammoth.translate.TranslationBuilder + :members: + + +Decoding Strategies +-------------------- +.. autoclass:: mammoth.translate.DecodeStrategy + :members: + +.. autoclass:: mammoth.translate.BeamSearch + :members: + +.. autofunction:: mammoth.translate.greedy_search.sample_with_temperature + +.. autoclass:: mammoth.translate.GreedySearch + :members: + +Scoring +-------- +.. autoclass:: mammoth.translate.penalties.PenaltyBuilder + :members: + +.. autoclass:: mammoth.translate.GNMTGlobalScorer + :members: diff --git a/docs/source/onmt.inputters.rst b/docs/source/onmt.inputters.rst deleted file mode 100644 index 99507e29..00000000 --- a/docs/source/onmt.inputters.rst +++ /dev/null @@ -1,20 +0,0 @@ -Data Loaders -================= - -Data Readers -------------- - -.. autoexception:: onmt.inputters.datareader_base.MissingDependencyException - -.. autoclass:: onmt.inputters.DataReaderBase - :members: - -.. autoclass:: onmt.inputters.TextDataReader - :members: - - -Dataset --------- - -.. autoclass:: onmt.inputters.Dataset - :members: diff --git a/docs/source/onmt.modules.rst b/docs/source/onmt.modules.rst deleted file mode 100644 index a3ef216e..00000000 --- a/docs/source/onmt.modules.rst +++ /dev/null @@ -1,109 +0,0 @@ -Modules -============= - -Core Modules ------------- - -.. autoclass:: onmt.modules.Embeddings - :members: - - -Encoders ---------- - -.. autoclass:: onmt.encoders.EncoderBase - :members: - -.. autoclass:: onmt.encoders.MeanEncoder - :members: - -.. autoclass:: onmt.encoders.RNNEncoder - :members: - - -Decoders ---------- - - -.. autoclass:: onmt.decoders.DecoderBase - :members: - -.. autoclass:: onmt.decoders.decoder.RNNDecoderBase - :members: - -.. autoclass:: onmt.decoders.StdRNNDecoder - :members: - -.. autoclass:: onmt.decoders.InputFeedRNNDecoder - :members: - -Attention ----------- - -.. autoclass:: onmt.modules.AverageAttention - :members: - -.. autoclass:: onmt.modules.GlobalAttention - :members: - - - -Architecture: Transformer ----------------------------- - -.. autoclass:: onmt.modules.PositionalEncoding - :members: - -.. autoclass:: onmt.modules.position_ffn.PositionwiseFeedForward - :members: - -.. autoclass:: onmt.encoders.TransformerEncoder - :members: - -.. autoclass:: onmt.decoders.TransformerDecoder - :members: - -.. autoclass:: onmt.modules.MultiHeadedAttention - :members: - :undoc-members: - - -Architecture: Conv2Conv ----------------------------- - -(These methods are from a user contribution -and have not been thoroughly tested.) - - -.. autoclass:: onmt.encoders.CNNEncoder - :members: - - -.. autoclass:: onmt.decoders.CNNDecoder - :members: - -.. autoclass:: onmt.modules.ConvMultiStepAttention - :members: - -.. autoclass:: onmt.modules.WeightNormConv2d - :members: - -Architecture: SRU ----------------------------- - -.. autoclass:: onmt.models.sru.SRU - :members: - - -Copy Attention --------------- - -.. autoclass:: onmt.modules.CopyGenerator - :members: - - -Structured Attention -------------------------------------------- - -.. autoclass:: onmt.modules.structured_attention.MatrixTree - :members: diff --git a/docs/source/onmt.rst b/docs/source/onmt.rst deleted file mode 100644 index 5ae056ce..00000000 --- a/docs/source/onmt.rst +++ /dev/null @@ -1,32 +0,0 @@ -Framework -================= - -Model ------ - -.. autoclass:: onmt.models.NMTModel - :members: - -Trainer -------- - -.. autoclass:: onmt.Trainer - :members: - - -.. autoclass:: onmt.utils.Statistics - :members: - -Loss ----- - - -.. autoclass:: onmt.utils.loss.LossComputeBase - :members: - - -Optimizer ---------- - -.. autoclass:: onmt.utils.Optimizer - :members: diff --git a/docs/source/onmt.translate.translation_server.rst b/docs/source/onmt.translate.translation_server.rst deleted file mode 100644 index 3426fade..00000000 --- a/docs/source/onmt.translate.translation_server.rst +++ /dev/null @@ -1,21 +0,0 @@ -Server -====== - - -Models -------------- - -.. autoclass:: onmt.translate.translation_server.ServerModel - :members: - - -Core Server ------------- - -.. autoexception:: onmt.translate.translation_server.ServerModelError - -.. autoclass:: onmt.translate.translation_server.Timer - :members: - -.. autoclass:: onmt.translate.translation_server.TranslationServer - :members: diff --git a/docs/source/onmt.translation.rst b/docs/source/onmt.translation.rst deleted file mode 100644 index bb6f5a5d..00000000 --- a/docs/source/onmt.translation.rst +++ /dev/null @@ -1,39 +0,0 @@ -Translation -================== - -Translations -------------- - -.. autoclass:: onmt.translate.Translation - :members: - -Translator Class ------------------ - -.. autoclass:: onmt.translate.Translator - :members: - -.. autoclass:: onmt.translate.TranslationBuilder - :members: - - -Decoding Strategies --------------------- -.. autoclass:: onmt.translate.DecodeStrategy - :members: - -.. autoclass:: onmt.translate.BeamSearch - :members: - -.. autofunction:: onmt.translate.greedy_search.sample_with_temperature - -.. autoclass:: onmt.translate.GreedySearch - :members: - -Scoring --------- -.. autoclass:: onmt.translate.penalties.PenaltyBuilder - :members: - -.. autoclass:: onmt.translate.GNMTGlobalScorer - :members: diff --git a/docs/source/options/build_vocab.rst b/docs/source/options/build_vocab.rst index 57fda68e..95bdc79b 100644 --- a/docs/source/options/build_vocab.rst +++ b/docs/source/options/build_vocab.rst @@ -2,7 +2,7 @@ Build Vocab =========== .. argparse:: - :filename: ../onmt/bin/build_vocab.py + :filename: ../mammoth/bin/build_vocab.py :func: _get_parser :prog: build_vocab.py diff --git a/docs/source/options/server.rst b/docs/source/options/server.rst index 63b2676f..b883d4fe 100644 --- a/docs/source/options/server.rst +++ b/docs/source/options/server.rst @@ -2,6 +2,6 @@ Server ========= .. argparse:: - :filename: ../onmt/bin/server.py + :filename: ../mammoth/bin/server.py :func: _get_parser :prog: server.py \ No newline at end of file diff --git a/docs/source/options/train.rst b/docs/source/options/train.rst index 67dc1cb2..066aa160 100644 --- a/docs/source/options/train.rst +++ b/docs/source/options/train.rst @@ -2,6 +2,6 @@ Train ===== .. argparse:: - :filename: ../onmt/bin/train.py + :filename: ../mammoth/bin/train.py :func: _get_parser :prog: train.py \ No newline at end of file diff --git a/docs/source/options/translate.rst b/docs/source/options/translate.rst index db0423a4..4b6244b7 100644 --- a/docs/source/options/translate.rst +++ b/docs/source/options/translate.rst @@ -2,6 +2,6 @@ Translate ========= .. argparse:: - :filename: ../onmt/bin/translate.py + :filename: ../mammoth/bin/translate.py :func: _get_parser :prog: translate.py \ No newline at end of file diff --git a/mammoth/__init__.py b/mammoth/__init__.py new file mode 100644 index 00000000..fd6ae773 --- /dev/null +++ b/mammoth/__init__.py @@ -0,0 +1,23 @@ +""" Main entry point of the Mammoth library """ +import mammoth.inputters +import mammoth.models +import mammoth.utils +import mammoth.modules +import mammoth.opts +from mammoth.trainer import Trainer +import sys +import mammoth.utils.optimizers + +mammoth.utils.optimizers.Optim = mammoth.utils.optimizers.Optimizer +sys.modules["mammoth.Optim"] = mammoth.utils.optimizers + +__all__ = [ + mammoth.inputters, + mammoth.models, + mammoth.utils, + mammoth.modules, + mammoth.opts, + "Trainer" +] + +__version__ = "2.2.0" diff --git a/onmt/bin/__init__.py b/mammoth/bin/__init__.py similarity index 100% rename from onmt/bin/__init__.py rename to mammoth/bin/__init__.py diff --git a/onmt/bin/average_models.py b/mammoth/bin/average_models.py similarity index 81% rename from onmt/bin/average_models.py rename to mammoth/bin/average_models.py index d9c09875..417b2c6c 100755 --- a/onmt/bin/average_models.py +++ b/mammoth/bin/average_models.py @@ -5,7 +5,7 @@ def average_models(model_files, fp32=False): vocab = None - opt = None + opts = None avg_model = None avg_generator = None @@ -21,7 +21,7 @@ def average_models(model_files, fp32=False): generator_weights[k] = v.float() if i == 0: - vocab, opt = m['vocab'], m['opt'] + vocab, opts = m['vocab'], m['opts'] avg_model = model_weights avg_generator = generator_weights else: @@ -31,7 +31,7 @@ def average_models(model_files, fp32=False): for (k, v) in avg_generator.items(): avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1) - final = {"vocab": vocab, "opt": opt, "optim": None, "generator": avg_generator, "model": avg_model} + final = {"vocab": vocab, "opts": opts, "optim": None, "generator": avg_generator, "model": avg_model} return final @@ -40,10 +40,10 @@ def main(): parser.add_argument("-models", "-m", nargs="+", required=True, help="List of models") parser.add_argument("-output", "-o", required=True, help="Output file") parser.add_argument("-fp32", "-f", action="store_true", help="Cast params to float32") - opt = parser.parse_args() + opts = parser.parse_args() - final = average_models(opt.models, opt.fp32) - torch.save(final, opt.output) + final = average_models(opts.models, opts.fp32) + torch.save(final, opts.output) if __name__ == "__main__": diff --git a/onmt/bin/build_vocab.py b/mammoth/bin/build_vocab.py similarity index 89% rename from onmt/bin/build_vocab.py rename to mammoth/bin/build_vocab.py index 77ba3bf3..65f408b1 100644 --- a/onmt/bin/build_vocab.py +++ b/mammoth/bin/build_vocab.py @@ -1,13 +1,13 @@ #!/usr/bin/env python """Get vocabulary coutings from transformed corpora samples.""" from collections import Counter, defaultdict -from onmt.utils.logging import init_logger -from onmt.utils.misc import set_random_seed, check_path -from onmt.utils.parse import ArgumentParser -from onmt.opts import dynamic_prepare_opts -from onmt.inputters import build_vocab_counts -from onmt.transforms import make_transforms, get_transforms_cls -from onmt.utils.distributed import TaskSpecs +from mammoth.utils.logging import init_logger +from mammoth.utils.misc import set_random_seed, check_path +from mammoth.utils.parse import ArgumentParser +from mammoth.opts import dynamic_prepare_opts +from mammoth.inputters import build_vocab_counts +from mammoth.transforms import make_transforms, get_transforms_cls +from mammoth.distributed import TaskSpecs def build_vocab_main(opts): @@ -32,8 +32,8 @@ def build_vocab_main(opts): src_counters_by_lang = defaultdict(Counter) tgt_counters_by_lang = defaultdict(Counter) - for corpus_id in opts.data: - lang_pair = opts.data[corpus_id]['src_tgt'] + for corpus_id in opts.tasks: + lang_pair = opts.tasks[corpus_id]['src_tgt'] src_lang, tgt_lang = lang_pair.split('-') task = TaskSpecs( node_rank=None, @@ -44,7 +44,7 @@ def build_vocab_main(opts): decoder_id=tgt_lang, corpus_id=f'{src_lang}-{tgt_lang}', weight=1, - corpus_opt=dict(), + corpus_opts=dict(), src_vocab=None, tgt_vocab=None, encoder_adapter_ids=None, diff --git a/onmt/bin/release_model.py b/mammoth/bin/release_model.py similarity index 75% rename from onmt/bin/release_model.py rename to mammoth/bin/release_model.py index 7adcf93c..354341da 100755 --- a/onmt/bin/release_model.py +++ b/mammoth/bin/release_model.py @@ -17,19 +17,19 @@ def main(): default=None, help="Quantization type for CT2 model.", ) - opt = parser.parse_args() + opts = parser.parse_args() - model = torch.load(opt.model, map_location=torch.device("cpu")) - if opt.format == "pytorch": + model = torch.load(opts.model, map_location=torch.device("cpu")) + if opts.format == "pytorch": model["optim"] = None - torch.save(model, opt.output) - elif opt.format == "ctranslate2": + torch.save(model, opts.output) + elif opts.format == "ctranslate2": import ctranslate2 if not hasattr(ctranslate2, "__version__"): raise RuntimeError("onmt_release_model script requires ctranslate2 >= 2.0.0") - converter = ctranslate2.converters.OpenNMTPyConverter(opt.model) - converter.convert(opt.output, force=True, quantization=opt.quantization) + converter = ctranslate2.converters.OpenNMTPyConverter(opts.model) + converter.convert(opts.output, force=True, quantization=opts.quantization) if __name__ == "__main__": diff --git a/onmt/bin/server.py b/mammoth/bin/server.py similarity index 97% rename from onmt/bin/server.py rename to mammoth/bin/server.py index 75abbaf4..7dae4712 100755 --- a/onmt/bin/server.py +++ b/mammoth/bin/server.py @@ -3,7 +3,7 @@ from flask import Flask, jsonify, request from waitress import serve -from onmt.translate import TranslationServer, ServerModelError +from mammoth.translate import TranslationServer, ServerModelError import logging from logging.handlers import RotatingFileHandler @@ -50,9 +50,9 @@ def clone_model(model_id): timeout = data['timeout'] del data['timeout'] - opt = data.get('opt', None) + opts = data.get('opts', None) try: - model_id, load_time = translation_server.clone_model(model_id, opt, timeout) + model_id, load_time = translation_server.clone_model(model_id, opts, timeout) except ServerModelError as e: out['status'] = STATUS_ERROR out['error'] = str(e) diff --git a/onmt/bin/train.py b/mammoth/bin/train.py similarity index 62% rename from onmt/bin/train.py rename to mammoth/bin/train.py index 2a376880..4c1d9072 100644 --- a/onmt/bin/train.py +++ b/mammoth/bin/train.py @@ -4,7 +4,7 @@ from functools import partial import os -from onmt.utils.distributed import ( +from mammoth.distributed import ( DeviceContext, DeviceContextEnum, ErrorHandler, @@ -13,79 +13,79 @@ batch_producer, consumer, ) -from onmt.utils.misc import set_random_seed -# from onmt.modules.embeddings import prepare_pretrained_embeddings -from onmt.utils.logging import init_logger, logger - -from onmt.models.model_saver import load_checkpoint -from onmt.train_single import main as single_main -from onmt.inputters import DynamicDatasetIter - -from onmt.utils.parse import ArgumentParser -from onmt.opts import train_opts -from onmt.inputters import get_vocab, DEFAULT_SPECIALS -from onmt.transforms import get_transforms_cls +from mammoth.utils.misc import set_random_seed +# from mammoth.modules.embeddings import prepare_pretrained_embeddings +from mammoth.utils.logging import init_logger, logger + +from mammoth.models.model_saver import load_checkpoint +from mammoth.train_single import main as single_main +from mammoth.inputters import DynamicDatasetIter + +from mammoth.utils.parse import ArgumentParser +from mammoth.opts import train_opts +from mammoth.inputters import get_vocab, DEFAULT_SPECIALS +from mammoth.transforms import get_transforms_cls from collections import OrderedDict -from onmt.constants import ModelTask +from mammoth.constants import ModelTask # Set sharing strategy manually instead of default based on the OS. torch.multiprocessing.set_sharing_strategy('file_system') -# def prepare_fields_transforms(opt): +# def prepare_fields_transforms(opts): # """Prepare or dump fields & transforms before training.""" -# transforms_cls = get_transforms_cls(opt._all_transform) -# specials = get_specials(opt, transforms_cls) +# transforms_cls = get_transforms_cls(opts._all_transform) +# specials = get_specials(opts, transforms_cls) # -# fields = build_dynamic_fields(opt, src_specials=specials['src'], tgt_specials=specials['tgt']) +# fields = build_dynamic_fields(opts, src_specials=specials['src'], tgt_specials=specials['tgt']) # # # maybe prepare pretrained embeddings, if any -# prepare_pretrained_embeddings(opt, fields) +# prepare_pretrained_embeddings(opts, fields) # -# if opt.dump_fields: -# save_fields(fields, opt.save_data, overwrite=opt.overwrite) -# if opt.dump_transforms or opt.n_sample != 0: -# transforms = make_transforms(opt, transforms_cls, fields) -# if opt.dump_transforms: -# save_transforms(transforms, opt.save_data, overwrite=opt.overwrite) -# if opt.n_sample != 0: +# if opts.dump_fields: +# save_fields(fields, opts.save_data, overwrite=opts.overwrite) +# if opts.dump_transforms or opts.n_sample != 0: +# transforms = make_transforms(opts, transforms_cls, fields) +# if opts.dump_transforms: +# save_transforms(transforms, opts.save_data, overwrite=opts.overwrite) +# if opts.n_sample != 0: # logger.warning( -# f"`-n_sample` != 0: Training will not be started. Stop after saving {opt.n_sample} samples/corpus." +# f"`-n_sample` != 0: Training will not be started. Stop after saving {opts.n_sample} samples/corpus." # ) -# save_transformed_sample(opt, transforms, n_sample=opt.n_sample) +# save_transformed_sample(opts, transforms, n_sample=opts.n_sample) # logger.info("Sample saved, please check it before restart training.") # sys.exit() # return fields, transforms_cls # TODO: reimplement save_transformed_sample -def _init_train(opt): +def _init_train(opts): """Common initilization stuff for all training process.""" - ArgumentParser.validate_prepare_opts(opt) + ArgumentParser.validate_prepare_opts(opts) - if opt.train_from: + if opts.train_from: # Load checkpoint if we resume from a previous training. - checkpoint = load_checkpoint(ckpt_path=opt.train_from) - # fields = load_fields(opt.save_data, checkpoint) - transforms_cls = get_transforms_cls(opt._all_transform) + checkpoint = load_checkpoint(ckpt_path=opts.train_from) + # fields = load_fields(opts.save_data, checkpoint) + transforms_cls = get_transforms_cls(opts._all_transform) if ( - hasattr(checkpoint["opt"], '_all_transform') - and len(opt._all_transform.symmetric_difference(checkpoint["opt"]._all_transform)) != 0 + hasattr(checkpoint["opts"], '_all_transform') + and len(opts._all_transform.symmetric_difference(checkpoint["opts"]._all_transform)) != 0 ): _msg = "configured transforms is different from checkpoint:" - new_transf = opt._all_transform.difference(checkpoint["opt"]._all_transform) - old_transf = checkpoint["opt"]._all_transform.difference(opt._all_transform) + new_transf = opts._all_transform.difference(checkpoint["opts"]._all_transform) + old_transf = checkpoint["opts"]._all_transform.difference(opts._all_transform) if len(new_transf) != 0: _msg += f" +{new_transf}" if len(old_transf) != 0: _msg += f" -{old_transf}." logger.warning(_msg) - if opt.update_vocab: + if opts.update_vocab: logger.info("Updating checkpoint vocabulary with new vocabulary") - # fields, transforms_cls = prepare_fields_transforms(opt) + # fields, transforms_cls = prepare_fields_transforms(opts) else: checkpoint = None - # fields, transforms_cls = prepare_fields_transforms(opt) + # fields, transforms_cls = prepare_fields_transforms(opts) # Report src and tgt vocab sizes # for side in ['src', 'tgt']: @@ -100,24 +100,24 @@ def _init_train(opt): return checkpoint, None, transforms_cls -# def init_train_prepare_fields_transforms(opt, vocab_path, side): +# def init_train_prepare_fields_transforms(opts, vocab_path, side): # """Prepare or dump fields & transforms before training.""" # -# fields = None # build_dynamic_fields_langspec(opt, vocab_path, side) -# transforms_cls = get_transforms_cls(opt._all_transform) -# # TODO: maybe prepare pretrained embeddings, if any, with `prepare_pretrained_embeddings(opt, fields)` +# fields = None # build_dynamic_fields_langspec(opts, vocab_path, side) +# transforms_cls = get_transforms_cls(opts._all_transform) +# # TODO: maybe prepare pretrained embeddings, if any, with `prepare_pretrained_embeddings(opts, fields)` # -# # if opt.dump_fields: -# # save_fields(fields, opt.save_data, overwrite=opt.overwrite) -# if opt.dump_transforms or opt.n_sample != 0: -# transforms = make_transforms(opt, transforms_cls, fields) -# if opt.dump_transforms: -# save_transforms(transforms, opt.save_data, overwrite=opt.overwrite) -# if opt.n_sample != 0: +# # if opts.dump_fields: +# # save_fields(fields, opts.save_data, overwrite=opts.overwrite) +# if opts.dump_transforms or opts.n_sample != 0: +# transforms = make_transforms(opts, transforms_cls, fields) +# if opts.dump_transforms: +# save_transforms(transforms, opts.save_data, overwrite=opts.overwrite) +# if opts.n_sample != 0: # logger.warning( -# f"`-n_sample` != 0: Training will not be started. Stop after saving {opt.n_sample} samples/corpus." +# f"`-n_sample` != 0: Training will not be started. Stop after saving {opts.n_sample} samples/corpus." # ) -# save_transformed_sample(opt, transforms, n_sample=opt.n_sample) +# save_transformed_sample(opts, transforms, n_sample=opts.n_sample) # logger.info("Sample saved, please check it before restart training.") # sys.exit() # @@ -127,7 +127,7 @@ def _init_train(opt): # return fields -def validate_slurm_node_opts(current_env, world_context, opt): +def validate_slurm_node_opts(current_env, world_context, opts): """If you are using slurm, confirm that opts match slurm environment variables""" slurm_n_nodes = int(current_env['SLURM_NNODES']) if slurm_n_nodes != world_context.n_nodes: @@ -136,35 +136,35 @@ def validate_slurm_node_opts(current_env, world_context, opt): f'but set n_nodes to {world_context.n_nodes} in the conf' ) slurm_node_id = int(current_env['SLURM_NODEID']) - if slurm_node_id != opt.node_rank: + if slurm_node_id != opts.node_rank: raise ValueError( f'Looks like you are running on slurm node {slurm_node_id}, ' - f'but set node_rank to {opt.node_rank} on the command line' + f'but set node_rank to {opts.node_rank} on the command line' ) -def train(opt): - init_logger(opt.log_file) - ArgumentParser.validate_train_opts(opt) - ArgumentParser.update_model_opts(opt) - ArgumentParser.validate_model_opts(opt) - ArgumentParser.validate_prepare_opts(opt) - set_random_seed(opt.seed, False) +def train(opts): + init_logger(opts.log_file) + ArgumentParser.validate_train_opts(opts) + ArgumentParser.update_model_opts(opts) + ArgumentParser.validate_model_opts(opts) + ArgumentParser.validate_prepare_opts(opts) + set_random_seed(opts.seed, False) # set PyTorch distributed related environment variables current_env = os.environ - current_env["WORLD_SIZE"] = str(opt.world_size) - world_context = WorldContext.from_opt(opt) + current_env["WORLD_SIZE"] = str(opts.world_size) + world_context = WorldContext.from_opts(opts) if 'SLURM_NNODES' in current_env: - validate_slurm_node_opts(current_env, world_context, opt) + validate_slurm_node_opts(current_env, world_context, opts) logger.info(f'Training on {world_context}') - opt.data_task = ModelTask.SEQ2SEQ + opts.data_task = ModelTask.SEQ2SEQ - transforms_cls = get_transforms_cls(opt._all_transform) + transforms_cls = get_transforms_cls(opts._all_transform) if transforms_cls: logger.info(f'All transforms: {transforms_cls}') - src_specials, tgt_specials = zip(*(cls.get_specials(opt) for cls in transforms_cls.values())) + src_specials, tgt_specials = zip(*(cls.get_specials(opts) for cls in transforms_cls.values())) all_specials = set(DEFAULT_SPECIALS) for special_group in src_specials + tgt_specials: all_specials = all_specials | special_group @@ -177,12 +177,12 @@ def train(opt): vocabs_dict = OrderedDict() # For creating fields, we use a task_queue_manager that doesn't filter by node and gpu - global_task_queue_manager = TaskQueueManager.from_opt(opt, world_context) + global_task_queue_manager = TaskQueueManager.from_opts(opts, world_context) - vocab_size = {'src': opt.src_vocab_size or None, 'tgt': opt.tgt_vocab_size or None} + vocab_size = {'src': opts.src_vocab_size or None, 'tgt': opts.tgt_vocab_size or None} for side in ('src', 'tgt'): for lang in global_task_queue_manager.get_langs(side): - vocab_path = opt.__getattribute__(f'{side}_vocab')[lang] + vocab_path = opts.__getattribute__(f'{side}_vocab')[lang] # FIXME: for now, all specials are passed to all vocabs, this could be finer-grained vocabs_dict[(side, lang)] = get_vocab(vocab_path, lang, vocab_size[side], specials=all_specials) # for key, val in fields_dict: @@ -193,14 +193,14 @@ def train(opt): logger.debug(f"[{os.getpid()}] Initializing process group with: {current_env}") if world_context.context == DeviceContextEnum.MULTI_GPU: - current_env["MASTER_ADDR"] = opt.master_ip - current_env["MASTER_PORT"] = str(opt.master_port) - node_rank = opt.node_rank + current_env["MASTER_ADDR"] = opts.master_ip + current_env["MASTER_PORT"] = str(opts.master_port) + node_rank = opts.node_rank queues = [] semaphores = [] mp = torch.multiprocessing.get_context('spawn') - logger.info("world_size = {}, queue_size = {}".format(opt.world_size, opt.queue_size)) + logger.info("world_size = {}, queue_size = {}".format(opts.world_size, opts.queue_size)) # Create a thread to listen for errors in the child processes. error_queue = mp.SimpleQueue() error_handler = ErrorHandler(error_queue) @@ -217,21 +217,21 @@ def train(opt): task_queue_manager = global_task_queue_manager.global_to_local( node_rank=node_rank, local_rank=local_rank, - opt=opt + opts=opts ) # store rank in env (FIXME: is this obsolete?) current_env["RANK"] = str(device_context.global_rank) current_env["LOCAL_RANK"] = str(device_context.local_rank) - q = mp.Queue(opt.queue_size) - semaphore = mp.Semaphore(opt.queue_size) + q = mp.Queue(opts.queue_size) + semaphore = mp.Semaphore(opts.queue_size) queues.append(q) semaphores.append(semaphore) procs.append( mp.Process( target=consumer, - args=(train_process, opt, device_context, error_queue, q, semaphore, task_queue_manager), + args=(train_process, opts, device_context, error_queue, q, semaphore, task_queue_manager), daemon=True, ) ) @@ -244,12 +244,12 @@ def train(opt): task_queue_manager=task_queue_manager, transforms_cls=transforms_cls, vocabs_dict=vocabs_dict, - opts=opt, + opts=opts, is_train=True, ) producer = mp.Process( - target=batch_producer, args=(train_iter, q, semaphore, opt, local_rank), daemon=True + target=batch_producer, args=(train_iter, q, semaphore, opts, local_rank), daemon=True ) producers.append(producer) producers[local_rank].start() @@ -272,9 +272,9 @@ def train(opt): task_queue_manager = global_task_queue_manager.global_to_local( node_rank=0, local_rank=0, - opt=opt + opts=opts ) - train_process(opt, device_context=device_context, task_queue_manager=task_queue_manager) + train_process(opts, device_context=device_context, task_queue_manager=task_queue_manager) def _get_parser(): @@ -286,8 +286,8 @@ def _get_parser(): def main(): parser = _get_parser() - opt, unknown = parser.parse_known_args() - train(opt) + opts, unknown = parser.parse_known_args() + train(opts) if __name__ == "__main__": diff --git a/mammoth/bin/translate.py b/mammoth/bin/translate.py new file mode 100644 index 00000000..8c86bbf0 --- /dev/null +++ b/mammoth/bin/translate.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from mammoth.utils.logging import init_logger +from mammoth.utils.misc import split_corpus +from mammoth.translate.translator import build_translator +# from mammoth.inputters.text_dataset import InferenceDataReader +from mammoth.transforms import get_transforms_cls, make_transforms, TransformPipe + +import mammoth.opts as opts +from mammoth.distributed import TaskSpecs +from mammoth.utils.parse import ArgumentParser + + +def translate(opts): + ArgumentParser.validate_translate_opts(opts) + ArgumentParser._get_all_transform_translate(opts) + ArgumentParser._validate_transforms_opts(opts) + ArgumentParser.validate_translate_opts_dynamic(opts) + logger = init_logger(opts.log_file) + + encoder_adapter_ids = set() + for layer_stack_idx, stack in enumerate(opts.stack['encoder']): + if 'adapters' in stack: + for group_id, sub_id in stack['adapters']: + encoder_adapter_ids.add((layer_stack_idx, group_id, sub_id)) + decoder_adapter_ids = set() + for layer_stack_idx, stack in enumerate(opts.stack['decoder']): + if 'adapters' in stack: + for group_id, sub_id in stack['adapters']: + decoder_adapter_ids.add((layer_stack_idx, group_id, sub_id)) + + logger.info( + 'It is ok that src_vocab and tgt_vocab are None here. ' + 'The vocabs are separately loaded in model_builder.' + ) + task = TaskSpecs( + node_rank=None, + local_rank=None, + src_lang=opts.src_lang, + tgt_lang=opts.tgt_lang, + encoder_id=[stack['id'] for stack in opts.stack['encoder']], + decoder_id=[stack['id'] for stack in opts.stack['decoder']], + corpus_id='trans', + weight=1, + corpus_opts=dict(), + src_vocab=None, + tgt_vocab=None, + encoder_adapter_ids=encoder_adapter_ids, + decoder_adapter_ids=decoder_adapter_ids, + ) + + translator = build_translator(opts, task, logger=logger, report_score=True) + + # data_reader = InferenceDataReader(opts.src, opts.tgt, opts.src_feats) + src_shards = split_corpus(opts.src, opts.shard_size) + tgt_shards = split_corpus(opts.tgt, opts.shard_size) + features_shards = [] + features_names = [] + for feat_name, feat_path in opts.src_feats.items(): + features_shards.append(split_corpus(feat_path, opts.shard_size)) + features_names.append(feat_name) + shard_pairs = zip(src_shards, tgt_shards, *features_shards) + + # Build transforms + transforms_cls = get_transforms_cls(opts._all_transform) + transforms = make_transforms(opts, transforms_cls, translator.vocabs, task=task) + data_transform = [ + transforms[name] for name in opts.transforms if name in transforms + ] + transform = TransformPipe.build_from(data_transform) + + for i, (src_shard, tgt_shard, *feats_shard) in enumerate(shard_pairs): + logger.info("Translating shard %d." % i) + translator.translate_dynamic( + src=src_shard, + transform=transform, + # src_feats=feats_shard, # TODO: put me back in + tgt=tgt_shard, + batch_size=opts.batch_size, + batch_type=opts.batch_type, + attn_debug=opts.attn_debug, + align_debug=opts.align_debug + ) + + +def _get_parser(): + parser = ArgumentParser(description='translate.py') + + opts.config_opts(parser) + opts.translate_opts(parser, dynamic=True) + opts.build_bilingual_model(parser) + return parser + + +def main(): + parser = _get_parser() + + opts = parser.parse_args() + translate(opts) + + +if __name__ == "__main__": + main() diff --git a/onmt/constants.py b/mammoth/constants.py similarity index 100% rename from onmt/constants.py rename to mammoth/constants.py diff --git a/mammoth/distributed/__init__.py b/mammoth/distributed/__init__.py new file mode 100644 index 00000000..5a032c0c --- /dev/null +++ b/mammoth/distributed/__init__.py @@ -0,0 +1,32 @@ +"""Module defining distributed communications utilities.""" +from mammoth.distributed.communication import ( + all_gather_list, + batch_producer, + consumer, + broadcast_tensors, + only_ready_reduce_and_rescale_grads, + ErrorHandler, +) +from mammoth.distributed.contexts import DeviceContext, WorldContext, DeviceContextEnum +from mammoth.distributed.tasks import ( + TaskSpecs, + TaskQueueManager, + DatasetMetadata, + TASK_DISTRIBUTION_STRATEGIES, +) + +__all__ = [ + "all_gather_list", + "batch_producer", + "broadcast_tensors", + "consumer", + "only_ready_reduce_and_rescale_grads", + "ErrorHandler", + "DeviceContext", + "WorldContext", + "DeviceContextEnum", + "TASK_DISTRIBUTION_STRATEGIES", + "DatasetMetadata", + "TaskQueueManager", + "TaskSpecs", +] diff --git a/mammoth/distributed/communication.py b/mammoth/distributed/communication.py new file mode 100644 index 00000000..687da4b9 --- /dev/null +++ b/mammoth/distributed/communication.py @@ -0,0 +1,282 @@ +"""Module defining low-level comunication utilities (initialization, brodcasting, etc.)""" +import math +import os +import pickle +import signal + +import torch +import torch.distributed + +from mammoth.utils.logging import init_logger, logger +from mammoth.utils.misc import set_random_seed + + +def multi_init(opts, global_rank): + dist_init_method = 'tcp://{master_ip}:{master_port}'.format(master_ip=opts.master_ip, master_port=opts.master_port) + + dist_world_size = opts.world_size + torch.distributed.init_process_group( + backend=opts.gpu_backend, + init_method=dist_init_method, + rank=global_rank, + world_size=dist_world_size, + ) + + gpu_rank = torch.distributed.get_rank() + + return gpu_rank + + +def broadcast_tensors(tensors, src=0, group=None): + for t in tensors: + if group is None: + torch.distributed.broadcast(t, src) + else: + torch.distributed.broadcast(t, src, group=group) + + +def only_ready_reduce_and_rescale_grads(named_parameters, group=None): + """ + Gradient synch tolerant to missing grads. + + Missing grads occur when some parameters are not trained between two + gradient synchs, e.g. the embeddings of a low-resource language with low + sampling weight. + + The algorithm first uses the 'has_grad' attribute set by the forward hook + 'has_grad_hook'. This hook ensures that all parameters of the modules + selected for use during the current training computation have 'has_grad' + set to True. This gives the list of parameters that have been trained on + this device ("ready"). + + A bit mask covering the parameters that are ready on this device is + communicated to the other devices in the group. The bit masks are reduced + using summation. The sum gives the number of real gradients for that + parameter, and can be used for normalization. + + If a parameter is ready on any device, all devices communicate a value. + Devices on which the parameter is ready communicate the actual gradient, + while devices on which it is not ready communicate a dummy zero tensor + instead. The sum computed previously is used for normalization. + + Args: + named_parameters: tuples of (str, Parameter) defining the parameters to consider + group: torch.distributed communication group + """ + # Set missing gradients to zero, keeping track of true gradients + require_grad = [(name, p) for (name, p) in named_parameters if p.requires_grad] + if not require_grad: + # Exit early if the component has no parameters that require a gradient + return + device = require_grad[0][1].device + ready_list = [] + for name, p in require_grad: + if hasattr(p, 'has_grad') and p.has_grad: + ready_list.append(1.0) + else: + ready_list.append(0.0) + if p.grad is None: + p.grad = torch.zeros_like(p) + + # Communicate the ready bits, and reduce them using summation. + # This gives the number of non-dummy gradients participating, for normalization + ready_t = torch.tensor(ready_list).to(device) + if group is None: + torch.distributed.all_reduce(ready_t) + else: + torch.distributed.all_reduce(ready_t, group=group) + rescale_denoms = ready_t # after reduction + + # Omit if all nodes sent a zero ready bit + denoms_mask = (rescale_denoms > 0).cpu() + params_with_grad = [p for ((name, p), m) in zip(require_grad, denoms_mask) if m] + grads = [p.grad.data for p in params_with_grad] + rescale_denoms = [denom for (denom, m) in zip(rescale_denoms, denoms_mask) if m] + assert len(grads) == len(rescale_denoms) + if len(grads) == 0: + return + + # If not, then set has_grad also on devices that did not train the parameter themselves. + # They now have a grad that they received from the other devices. + for name, p in require_grad: + p.has_grad = True + + # All devices communicate either a real gradient or a dummy zeros of the same size + # Can not use rescale_denom, as each grad may have its own denominator + all_reduce_and_rescale_tensors(grads, rescale_denom=1, group=group) + + # Normalize using the previously computed values + for grad, denom in zip(grads, rescale_denoms): + if denom > 1: + grad.div_(denom) + # Note: p.has_grad is reused in the optimizer to prevent the untrained components from being stepped + + +def all_reduce_and_rescale_tensors(tensors, rescale_denom, group=None, buffer_size=10485760): + """ + All-reduce and rescale tensors in chunks of the specified size. + + Args: + tensors: list of Tensors to all-reduce + rescale_denom: denominator for rescaling summed Tensors + buffer_size: all-reduce chunk size in bytes + """ + # buffer size in bytes, determine equiv. # of elements based on data type + buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_() + buffer = [] + + def all_reduce_buffer(): + # copy tensors into buffer_t + offset = 0 + for t in buffer: + numel = t.numel() + buffer_t[offset:offset + numel].copy_(t.view(-1)) + offset += numel + + # all-reduce and rescale + if group is None: + torch.distributed.all_reduce(buffer_t[:offset]) + else: + torch.distributed.all_reduce(buffer_t[:offset], group=group) + buffer_t.div_(rescale_denom) + + # copy all-reduced buffer back into tensors + offset = 0 + for t in buffer: + numel = t.numel() + t.view(-1).copy_(buffer_t[offset:offset + numel]) + offset += numel + + filled = 0 + for t in tensors: + sz = t.numel() * t.element_size() + if sz > buffer_size: + # tensor is bigger than buffer, all-reduce and rescale directly + if group is None: + torch.distributed.all_reduce(t) + else: + torch.distributed.all_reduce(t, group=group) + t.div_(rescale_denom) + elif filled + sz > buffer_size: + # buffer is full, all-reduce and replace buffer with grad + all_reduce_buffer() + buffer = [t] + filled = sz + else: + # add tensor to buffer + buffer.append(t) + filled += sz + + if len(buffer) > 0: + all_reduce_buffer() + + +def all_gather_list(data, max_size=4096): + """Gathers arbitrary data from all nodes into a list.""" + world_size = torch.distributed.get_world_size() + if not hasattr(all_gather_list, '_in_buffer') or max_size != all_gather_list._in_buffer.size(): + all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) + all_gather_list._out_buffers = [torch.cuda.ByteTensor(max_size) for i in range(world_size)] + in_buffer = all_gather_list._in_buffer + out_buffers = all_gather_list._out_buffers + + enc = pickle.dumps(data) + enc_size = len(enc) + if enc_size + 2 > max_size: + raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2)) + assert max_size < 255 * 256 + in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k + in_buffer[1] = enc_size % 255 + in_buffer[2:enc_size + 2] = torch.ByteTensor(list(enc)) + + torch.distributed.all_gather(out_buffers, in_buffer.cuda()) + + results = [] + for i in range(world_size): + out_buffer = out_buffers[i] + size = (255 * out_buffer[0].item()) + out_buffer[1].item() + + bytes_list = bytes(out_buffer[2:size + 2].tolist()) + result = pickle.loads(bytes_list) + results.append(result) + return results + + +class ErrorHandler(object): + """A class that listens for exceptions in children processes and propagates + the tracebacks to the parent process.""" + + def __init__(self, error_queue): + """init error handler""" + import signal + import threading + + self.error_queue = error_queue + self.children_pids = [] + self.error_thread = threading.Thread(target=self.error_listener, daemon=True) + self.error_thread.start() + signal.signal(signal.SIGUSR1, self.signal_handler) + + def add_child(self, pid): + """error handler""" + self.children_pids.append(pid) + + def error_listener(self): + """error listener""" + (rank, original_trace) = self.error_queue.get() + self.error_queue.put((rank, original_trace)) + os.kill(os.getpid(), signal.SIGUSR1) + + def signal_handler(self, signalnum, stackframe): + """signal handler""" + for pid in self.children_pids: + os.kill(pid, signal.SIGINT) # kill children processes + (rank, original_trace) = self.error_queue.get() + msg = """\n\n-- Tracebacks above this line can probably + be ignored --\n\n""" + msg += original_trace + raise Exception(msg) + + +def batch_producer(generator_to_serve, queue, semaphore, opts, device_id): + """Produce batches to `queues` from `generator_to_serve`.""" + log_level = "INFO" if opts.verbose or device_id == 0 else "WARNING" + init_logger(opts.log_file, log_level=log_level) + set_random_seed(opts.seed, False) + logger.info("BATCH PRODUCER") + logger.info(generator_to_serve) + + for batch, metadata, communication_batch_id in generator_to_serve: + semaphore.acquire() + # Move batch to correspond device_id when consumer iterate + # hack to dodge unpicklable `dict_keys` + # batch.fields = list(batch.fields) + queue.put((batch, metadata, communication_batch_id)) + + +def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager): + """Run `process_fn` on `device_id` with data from `batch_queue`.""" + try: + logger.info( + f'global_rank {device_context.global_rank} ' + f'node_rank {device_context.node_rank} ' + f'local_rank {device_context.local_rank}' + ) + logger.info(f'opts.gpu_ranks {opts.gpu_ranks}') + multi_init(opts, device_context.global_rank) + # error_queue not passed (is this intentional?) + process_fn( + opts, + device_context=device_context, + batch_queue=batch_queue, + semaphore=semaphore, + task_queue_manager=task_queue_manager, + ) + + except KeyboardInterrupt: + pass # killed by parent, do nothing + except Exception: + # propagate exception to parent process, keeping original traceback + import traceback + + error_queue.put((opts.gpu_ranks[device_context.node_rank], traceback.format_exc())) diff --git a/mammoth/distributed/contexts.py b/mammoth/distributed/contexts.py new file mode 100644 index 00000000..8a8e4241 --- /dev/null +++ b/mammoth/distributed/contexts.py @@ -0,0 +1,105 @@ +from dataclasses import dataclass +from enum import Enum + + +class DeviceContextEnum(Enum): + CPU = 1 + SINGLE_GPU = 2 + MULTI_GPU = 3 + + +@dataclass +class WorldContext: + context: DeviceContextEnum + # Size of the world: total number of nodes, gpus on each node + n_nodes: int + gpus_per_node: int + + @property + def world_size(self): + """Total number of training GPUs""" + return self.n_nodes * self.gpus_per_node + + def is_distributed(self): + """When training is distributed over several devices, + multiprocessing is used to communicate gradients""" + return self.context == DeviceContextEnum.MULTI_GPU + + def is_gpu(self): + """Data tensors must be moved to the GPU for compute""" + return self.context != DeviceContextEnum.CPU + + def is_master(self): + """For code that should only run in one process: + - saving fully shared modules from one device only + - avoiding log spam when all devices would log the same result + """ + return not self.is_distributed() or self.global_rank == 0 + + def global_to_local(self, node_rank, local_rank): + assert node_rank is not None + assert local_rank is not None + return DeviceContext( + context=self.context, + n_nodes=self.n_nodes, + gpus_per_node=self.gpus_per_node, + node_rank=node_rank, + local_rank=local_rank, + ) + + @classmethod + def from_opts(cls, opts): + gpus_per_node = len(opts.gpu_ranks) + world_size = int(opts.world_size) if gpus_per_node > 0 else 0 + multinode = gpus_per_node != world_size + if world_size <= 0: + # setting a non-positive world size means use CPU + device_context_enum = DeviceContextEnum.CPU + if opts.n_nodes != 1: + raise ValueError('CPU training is only possible on a single node') + elif world_size == 1: + # world size 1 uses GPU, but is not distributed + device_context_enum = DeviceContextEnum.SINGLE_GPU + if opts.n_nodes != 1: + raise ValueError( + f'Invalid single-gpu node configuration: ' + f'n_nodes {opts.n_nodes} gpus_per_node {gpus_per_node} world_size {world_size}' + ) + else: + # world size > 1 + if multinode and opts.n_nodes == 1: + raise ValueError( + f'Invalid multi-node configuration: ' + f'n_nodes {opts.n_nodes} gpus_per_node {gpus_per_node} world_size {world_size}' + ) + device_context_enum = DeviceContextEnum.MULTI_GPU + world_context = WorldContext(context=device_context_enum, n_nodes=opts.n_nodes, gpus_per_node=gpus_per_node) + return world_context + + +@dataclass +class DeviceContext(WorldContext): + # Our place in the world + node_rank: int + local_rank: int + + @property + def global_rank(self) -> int: + return self.gpus_per_node * self.node_rank + self.local_rank + + @property + def id(self) -> str: + if self.is_gpu(): + return f'GPU {self.node_rank}:{self.local_rank}' + else: + return 'CPU' + + def validate(self, world_context): + # check that this DeviceContext is consistent with given WorldContext + assert self.context == world_context.context + assert self.n_nodes == world_context.n_nodes + assert self.gpus_per_node == world_context.gpus_per_node + # check that ranks are within the specified size of the world + assert 0 <= self.node_rank < self.n_nodes + if self.is_gpu(): + assert 0 <= self.local_rank < self.gpus_per_node diff --git a/onmt/utils/distributed.py b/mammoth/distributed/tasks.py similarity index 58% rename from onmt/utils/distributed.py rename to mammoth/distributed/tasks.py index bb9f6263..d6023583 100644 --- a/onmt/utils/distributed.py +++ b/mammoth/distributed/tasks.py @@ -1,409 +1,31 @@ -""" Pytorch Distributed utils - This piece of code was heavily inspired by the equivalent of Fairseq-py - https://github.com/pytorch/fairseq -""" -import math -import numpy as np -import os -import pickle -import signal -import torch.distributed - +"""sub-module defining tasks, task specifications and task management objects.""" from abc import ABC, abstractmethod from argparse import Namespace from collections import OrderedDict, namedtuple, Counter from dataclasses import dataclass -from enum import Enum from itertools import cycle, islice from pprint import pformat from typing import Any, Optional, List -from onmt.utils.logging import init_logger, logger -from onmt.utils.misc import set_random_seed - - -class DeviceContextEnum(Enum): - CPU = 1 - SINGLE_GPU = 2 - MULTI_GPU = 3 - - -@dataclass -class WorldContext: - context: DeviceContextEnum - # Size of the world: total number of nodes, gpus on each node - n_nodes: int - gpus_per_node: int - - @property - def world_size(self): - """Total number of training GPUs""" - return self.n_nodes * self.gpus_per_node - - def is_distributed(self): - """When training is distributed over several devices, - multiprocessing is used to communicate gradients""" - return self.context == DeviceContextEnum.MULTI_GPU - - def is_gpu(self): - """Data tensors must be moved to the GPU for compute""" - return self.context != DeviceContextEnum.CPU - - def is_master(self): - """For code that should only run in one process: - - saving fully shared modules from one device only - - avoiding log spam when all devices would log the same result - """ - return not self.is_distributed() or self.global_rank == 0 - - def global_to_local(self, node_rank, local_rank): - assert node_rank is not None - assert local_rank is not None - return DeviceContext( - context=self.context, - n_nodes=self.n_nodes, - gpus_per_node=self.gpus_per_node, - node_rank=node_rank, - local_rank=local_rank, - ) - - @classmethod - def from_opt(cls, opt): - gpus_per_node = len(opt.gpu_ranks) - world_size = int(opt.world_size) if gpus_per_node > 0 else 0 - multinode = gpus_per_node != world_size - if world_size <= 0: - # setting a non-positive world size means use CPU - device_context_enum = DeviceContextEnum.CPU - if opt.n_nodes != 1: - raise ValueError('CPU training is only possible on a single node') - elif world_size == 1: - # world size 1 uses GPU, but is not distributed - device_context_enum = DeviceContextEnum.SINGLE_GPU - if opt.n_nodes != 1: - raise ValueError( - f'Invalid single-gpu node configuration: ' - f'n_nodes {opt.n_nodes} gpus_per_node {gpus_per_node} world_size {world_size}' - ) - else: - # world size > 1 - if multinode and opt.n_nodes == 1: - raise ValueError( - f'Invalid multi-node configuration: ' - f'n_nodes {opt.n_nodes} gpus_per_node {gpus_per_node} world_size {world_size}' - ) - device_context_enum = DeviceContextEnum.MULTI_GPU - world_context = WorldContext(context=device_context_enum, n_nodes=opt.n_nodes, gpus_per_node=gpus_per_node) - return world_context - - -@dataclass -class DeviceContext(WorldContext): - # Our place in the world - node_rank: int - local_rank: int - - @property - def global_rank(self) -> int: - return self.gpus_per_node * self.node_rank + self.local_rank - - @property - def id(self) -> str: - if self.is_gpu(): - return f'GPU {self.node_rank}:{self.local_rank}' - else: - return 'CPU' - - def validate(self, world_context): - # check that this DeviceContext is consistent with given WorldContext - assert self.context == world_context.context - assert self.n_nodes == world_context.n_nodes - assert self.gpus_per_node == world_context.gpus_per_node - # check that ranks are within the specified size of the world - assert 0 <= self.node_rank < self.n_nodes - if self.is_gpu(): - assert 0 <= self.local_rank < self.gpus_per_node - - -def multi_init(opt, global_rank): - dist_init_method = 'tcp://{master_ip}:{master_port}'.format(master_ip=opt.master_ip, master_port=opt.master_port) - - dist_world_size = opt.world_size - torch.distributed.init_process_group( - backend=opt.gpu_backend, - init_method=dist_init_method, - rank=global_rank, - world_size=dist_world_size, - ) - - gpu_rank = torch.distributed.get_rank() - - return gpu_rank - +import numpy as np +import torch +import torch.distributed -def broadcast_tensors(tensors, src=0, group=None): - for t in tensors: - if group is None: - torch.distributed.broadcast(t, src) - else: - torch.distributed.broadcast(t, src, group=group) +from mammoth.distributed.contexts import DeviceContext, WorldContext +from mammoth.utils.logging import logger -def only_ready_reduce_and_rescale_grads(named_parameters, group=None): - """ - Gradient synch tolerant to missing grads. - - Missing grads occur when some parameters are not trained between two - gradient synchs, e.g. the embeddings of a low-resource language with low - sampling weight. - - The algorithm first uses the 'has_grad' attribute set by the forward hook - 'has_grad_hook'. This hook ensures that all parameters of the modules - selected for use during the current training computation have 'has_grad' - set to True. This gives the list of parameters that have been trained on - this device ("ready"). - - A bit mask covering the parameters that are ready on this device is - communicated to the other devices in the group. The bit masks are reduced - using summation. The sum gives the number of real gradients for that - parameter, and can be used for normalization. - - If a parameter is ready on any device, all devices communicate a value. - Devices on which the parameter is ready communicate the actual gradient, - while devices on which it is not ready communicate a dummy zero tensor - instead. The sum computed previously is used for normalization. - - Args: - named_parameters: tuples of (str, Parameter) defining the parameters to consider - group: torch.distributed communication group - """ - # Set missing gradients to zero, keeping track of true gradients - require_grad = [(name, p) for (name, p) in named_parameters if p.requires_grad] - if not require_grad: - # Exit early if the component has no parameters that require a gradient - return - device = require_grad[0][1].device - ready_list = [] - for name, p in require_grad: - if hasattr(p, 'has_grad') and p.has_grad: - ready_list.append(1.0) - else: - ready_list.append(0.0) - if p.grad is None: - p.grad = torch.zeros_like(p) - - # Communicate the ready bits, and reduce them using summation. - # This gives the number of non-dummy gradients participating, for normalization - ready_t = torch.tensor(ready_list).to(device) - if group is None: - torch.distributed.all_reduce(ready_t) - else: - torch.distributed.all_reduce(ready_t, group=group) - rescale_denoms = ready_t # after reduction - - # Omit if all nodes sent a zero ready bit - denoms_mask = (rescale_denoms > 0).cpu() - params_with_grad = [p for ((name, p), m) in zip(require_grad, denoms_mask) if m] - grads = [p.grad.data for p in params_with_grad] - rescale_denoms = [denom for (denom, m) in zip(rescale_denoms, denoms_mask) if m] - assert len(grads) == len(rescale_denoms) - if len(grads) == 0: - return - - # If not, then set has_grad also on devices that did not train the parameter themselves. - # They now have a grad that they received from the other devices. - for name, p in require_grad: - p.has_grad = True - - # All devices communicate either a real gradient or a dummy zeros of the same size - # Can not use rescale_denom, as each grad may have its own denominator - all_reduce_and_rescale_tensors(grads, rescale_denom=1, group=group) - - # Normalize using the previously computed values - for grad, denom in zip(grads, rescale_denoms): - if denom > 1: - grad.div_(denom) - # Note: p.has_grad is reused in the optimizer to prevent the untrained components from being stepped - - -def all_reduce_and_rescale_tensors(tensors, rescale_denom, group=None, buffer_size=10485760): +class TaskDistributionStrategy(ABC): """ - All-reduce and rescale tensors in chunks of the specified size. - - Args: - tensors: list of Tensors to all-reduce - rescale_denom: denominator for rescaling summed Tensors - buffer_size: all-reduce chunk size in bytes + An abstract task distribution strategy, controls which task will be scheduled next. """ - # buffer size in bytes, determine equiv. # of elements based on data type - buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_() - buffer = [] - - def all_reduce_buffer(): - # copy tensors into buffer_t - offset = 0 - for t in buffer: - numel = t.numel() - buffer_t[offset:offset + numel].copy_(t.view(-1)) - offset += numel - - # all-reduce and rescale - if group is None: - torch.distributed.all_reduce(buffer_t[:offset]) - else: - torch.distributed.all_reduce(buffer_t[:offset], group=group) - buffer_t.div_(rescale_denom) - - # copy all-reduced buffer back into tensors - offset = 0 - for t in buffer: - numel = t.numel() - t.view(-1).copy_(buffer_t[offset:offset + numel]) - offset += numel - - filled = 0 - for t in tensors: - sz = t.numel() * t.element_size() - if sz > buffer_size: - # tensor is bigger than buffer, all-reduce and rescale directly - if group is None: - torch.distributed.all_reduce(t) - else: - torch.distributed.all_reduce(t, group=group) - t.div_(rescale_denom) - elif filled + sz > buffer_size: - # buffer is full, all-reduce and replace buffer with grad - all_reduce_buffer() - buffer = [t] - filled = sz - else: - # add tensor to buffer - buffer.append(t) - filled += sz - - if len(buffer) > 0: - all_reduce_buffer() - - -def all_gather_list(data, max_size=4096): - """Gathers arbitrary data from all nodes into a list.""" - world_size = torch.distributed.get_world_size() - if not hasattr(all_gather_list, '_in_buffer') or max_size != all_gather_list._in_buffer.size(): - all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) - all_gather_list._out_buffers = [torch.cuda.ByteTensor(max_size) for i in range(world_size)] - in_buffer = all_gather_list._in_buffer - out_buffers = all_gather_list._out_buffers - - enc = pickle.dumps(data) - enc_size = len(enc) - if enc_size + 2 > max_size: - raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2)) - assert max_size < 255 * 256 - in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k - in_buffer[1] = enc_size % 255 - in_buffer[2:enc_size + 2] = torch.ByteTensor(list(enc)) - - torch.distributed.all_gather(out_buffers, in_buffer.cuda()) - - results = [] - for i in range(world_size): - out_buffer = out_buffers[i] - size = (255 * out_buffer[0].item()) + out_buffer[1].item() - - bytes_list = bytes(out_buffer[2:size + 2].tolist()) - result = pickle.loads(bytes_list) - results.append(result) - return results - - -class ErrorHandler(object): - """A class that listens for exceptions in children processes and propagates - the tracebacks to the parent process.""" - - def __init__(self, error_queue): - """init error handler""" - import signal - import threading - - self.error_queue = error_queue - self.children_pids = [] - self.error_thread = threading.Thread(target=self.error_listener, daemon=True) - self.error_thread.start() - signal.signal(signal.SIGUSR1, self.signal_handler) - - def add_child(self, pid): - """error handler""" - self.children_pids.append(pid) - - def error_listener(self): - """error listener""" - (rank, original_trace) = self.error_queue.get() - self.error_queue.put((rank, original_trace)) - os.kill(os.getpid(), signal.SIGUSR1) - - def signal_handler(self, signalnum, stackframe): - """signal handler""" - for pid in self.children_pids: - os.kill(pid, signal.SIGINT) # kill children processes - (rank, original_trace) = self.error_queue.get() - msg = """\n\n-- Tracebacks above this line can probably - be ignored --\n\n""" - msg += original_trace - raise Exception(msg) - - -def batch_producer(generator_to_serve, queue, semaphore, opt, device_id): - """Produce batches to `queues` from `generator_to_serve`.""" - log_level = "INFO" if opt.verbose or device_id == 0 else "WARNING" - init_logger(opt.log_file, log_level=log_level) - set_random_seed(opt.seed, False) - logger.info("BATCH PRODUCER") - logger.info(generator_to_serve) - - for batch, metadata, communication_batch_id in generator_to_serve: - semaphore.acquire() - # Move batch to correspond device_id when consumer iterate - # hack to dodge unpicklable `dict_keys` - # batch.fields = list(batch.fields) - queue.put((batch, metadata, communication_batch_id)) - - -def consumer(process_fn, opt, device_context, error_queue, batch_queue, semaphore, task_queue_manager): - """Run `process_fn` on `device_id` with data from `batch_queue`.""" - try: - logger.info( - f'global_rank {device_context.global_rank} ' - f'node_rank {device_context.node_rank} ' - f'local_rank {device_context.local_rank}' - ) - logger.info(f'opt.gpu_ranks {opt.gpu_ranks}') - multi_init(opt, device_context.global_rank) - # error_queue not passed (is this intentional?) - process_fn( - opt, - device_context=device_context, - batch_queue=batch_queue, - semaphore=semaphore, - task_queue_manager=task_queue_manager, - ) - - except KeyboardInterrupt: - pass # killed by parent, do nothing - except Exception: - # propagate exception to parent process, keeping original traceback - import traceback - - error_queue.put((opt.gpu_ranks[device_context.node_rank], traceback.format_exc())) - - -class TaskDistributionStrategy(ABC): @abstractmethod def __init__(self, my_corpus_ids: List[str], **kwargs): pass @classmethod @abstractmethod - def from_opt(cls, my_corpus_ids: List[str], opt: dict): + def from_opts(cls, my_corpus_ids: List[str], opts: dict): pass @abstractmethod @@ -442,10 +64,10 @@ def __init__( raise ValueError('Invalid curriculum: no corpus is ready to start in the first step') @classmethod - def from_opt(cls, my_corpus_ids: List[str], opt: dict): - my_weights = [opt.data[corpus_id]['weight'] for corpus_id in my_corpus_ids] + def from_opts(cls, my_corpus_ids: List[str], opts: dict): + my_weights = [opts.tasks[corpus_id]['weight'] for corpus_id in my_corpus_ids] my_introduce_at_training_step = [ - opt.data[corpus_id]['introduce_at_training_step'] for corpus_id in my_corpus_ids + opts.tasks[corpus_id]['introduce_at_training_step'] for corpus_id in my_corpus_ids ] return cls(my_corpus_ids, my_weights, my_introduce_at_training_step) @@ -479,7 +101,7 @@ def __init__(self, my_corpus_ids: List[str]): self.infinite_corpus_ids = cycle(my_corpus_ids) @classmethod - def from_opt(cls, my_corpus_ids: List[str], opt: dict): + def from_opts(cls, my_corpus_ids: List[str], opts: dict): return cls(my_corpus_ids) def sample_corpus_ids( @@ -511,7 +133,7 @@ class TaskSpecs(): decoder_id: List[str] corpus_id: str weight: int - corpus_opt: dict + corpus_opts: dict src_vocab: Any # FIXME: type tgt_vocab: Any encoder_adapter_ids: List[str] @@ -534,11 +156,11 @@ def get_serializable_metadata(self): ) -def get_adapter_ids(opt, corpus_opt, side): - if 'adapters' not in opt or 'adapters' not in corpus_opt: +def get_adapter_ids(opts, corpus_opts, side): + if 'adapters' not in opts or 'adapters' not in corpus_opts: return [] - global_adapters_opt = opt.adapters.get(side, None) - corpus_adapter_opt = corpus_opt['adapters'].get(side, None) + global_adapters_opt = opts.adapters.get(side, None) + corpus_adapter_opt = corpus_opts['adapters'].get(side, None) if not global_adapters_opt or not corpus_adapter_opt: return [] result = [] @@ -607,17 +229,17 @@ def local_rank(self): return self.device_context.local_rank @classmethod - def from_opt(cls, opt: Namespace, world_context: WorldContext): - n_tasks = len(opt.data) + def from_opts(cls, opts: Namespace, world_context: WorldContext): + n_tasks = len(opts.tasks) # Sorting the keys, to ensure that tasks have a consistent order across devices. # This in turn ensures the order in which components are created from those tasks. - corpus_ids = sorted(opt.data.keys()) + corpus_ids = sorted(opts.tasks.keys()) if world_context.is_distributed(): - if any(task.get('node_gpu', None) is not None for task in opt.data.values()): + if any(task.get('node_gpu', None) is not None for task in opts.tasks.values()): node_gpu = [ - tuple(int(y) for y in opt.data[corpus_id]['node_gpu'].split(':', 1)) + tuple(int(y) for y in opts.tasks[corpus_id]['node_gpu'].split(':', 1)) for corpus_id in corpus_ids] else: # When --node_gpu is not set, assume an assigment that fills gpus in rank order @@ -626,24 +248,24 @@ def from_opt(cls, opt: Namespace, world_context: WorldContext): node_gpu = [(0, 0)] * n_tasks enc_sharing_group = [ - opt.data[corpus_id].get('enc_sharing_group', None) for corpus_id in corpus_ids + opts.tasks[corpus_id].get('enc_sharing_group', None) for corpus_id in corpus_ids ] dec_sharing_group = [ - opt.data[corpus_id].get('dec_sharing_group', None) for corpus_id in corpus_ids + opts.tasks[corpus_id].get('dec_sharing_group', None) for corpus_id in corpus_ids ] if any(x is not None for x in enc_sharing_group): - assert all(len(enc_ids) == len(opt.enc_layers) for enc_ids in enc_sharing_group) + assert all(len(enc_ids) == len(opts.enc_layers) for enc_ids in enc_sharing_group) else: # if no encoder sharing groups are defined, # it is assumed that there is only one encoder stack and it is language specific - if not len(opt.enc_layers) == 1: + if not len(opts.enc_layers) == 1: raise Exception('With more than one encoder stack, you must explictly define enc_sharing_group') if any(x is not None for x in dec_sharing_group): - assert all(len(dec_ids) == len(opt.dec_layers) for dec_ids in dec_sharing_group) + assert all(len(dec_ids) == len(opts.dec_layers) for dec_ids in dec_sharing_group) else: # if no decoder sharing groups are defined, # it is assumed that there is only one decoder stack and it is language specific - if not len(opt.dec_layers) == 1: + if not len(opts.dec_layers) == 1: raise Exception('With more than one decoder stack, you must explictly define dec_sharing_group') tasks = [] @@ -655,14 +277,14 @@ def from_opt(cls, opt: Namespace, world_context: WorldContext): node_gpu, corpus_ids ): - corpus_opt = opt.data[corpus_id] - src_lang, tgt_lang = corpus_opt['src_tgt'].split('-', 1) - encoder_id = corpus_opt.get('enc_sharing_group', [src_lang]) - decoder_id = corpus_opt.get('dec_sharing_group', [tgt_lang]) - weight = corpus_opt.get('weight', 1.0) - if 'adapters' in corpus_opt: - encoder_adapter_ids = get_adapter_ids(opt, corpus_opt, 'encoder') - decoder_adapter_ids = get_adapter_ids(opt, corpus_opt, 'decoder') + corpus_opts = opts.tasks[corpus_id] + src_lang, tgt_lang = corpus_opts['src_tgt'].split('-', 1) + encoder_id = corpus_opts.get('enc_sharing_group', [src_lang]) + decoder_id = corpus_opts.get('dec_sharing_group', [tgt_lang]) + weight = corpus_opts.get('weight', 1.0) + if 'adapters' in corpus_opts: + encoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'encoder') + decoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'decoder') uses_adapters = True else: encoder_adapter_ids = None @@ -676,7 +298,7 @@ def from_opt(cls, opt: Namespace, world_context: WorldContext): decoder_id=decoder_id, corpus_id=corpus_id, weight=weight, - corpus_opt=corpus_opt, + corpus_opts=corpus_opts, src_vocab=None, tgt_vocab=None, encoder_adapter_ids=encoder_adapter_ids, @@ -686,14 +308,14 @@ def from_opt(cls, opt: Namespace, world_context: WorldContext): return cls( tasks, world_context=world_context, - tasks_per_communication_batch=opt.accum_count, + tasks_per_communication_batch=opts.accum_count, uses_adapters=uses_adapters, ) - def global_to_local(self, node_rank, local_rank, opt): + def global_to_local(self, node_rank, local_rank, opts): assert node_rank is not None assert local_rank is not None - task_distribution_strategy = self._get_strategy(node_rank=node_rank, local_rank=local_rank, opt=opt) + task_distribution_strategy = self._get_strategy(node_rank=node_rank, local_rank=local_rank, opts=opts) device_context = self.world_context.global_to_local(node_rank, local_rank) return self.__class__( self.tasks, @@ -706,15 +328,15 @@ def global_to_local(self, node_rank, local_rank, opt): uses_adapters=self.uses_adapters, ) - def _get_strategy(self, node_rank, local_rank, opt): + def _get_strategy(self, node_rank, local_rank, opts): assert node_rank is not None assert local_rank is not None # Global TQM does not have a task distribution strategy, but the local ones do my_corpus_ids = [task.corpus_id for task in self._tasks_on_device(node_rank, local_rank)] try: - strategy = TASK_DISTRIBUTION_STRATEGIES[opt.task_distribution_strategy].from_opt( + strategy = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy].from_opts( my_corpus_ids=my_corpus_ids, - opt=opt, + opts=opts, ) return strategy except Exception as e: @@ -936,11 +558,11 @@ def get_fields(self, side: str, fields_dict): raise RuntimeError # FIXME: merge with below - def get_vocabularies(self, opt: Namespace, side: str): + def get_vocabularies(self, opts: Namespace, side: str): result = [] for task in self.get_tasks(): lang = self.src_lang if side == 'src' else self.tgt_lang - vocab_path = opt.__getattribute__(f'{side}_vocab')[lang] + vocab_path = opts.__getattribute__(f'{side}_vocab')[lang] result.append((lang, vocab_path)) return result diff --git a/onmt/inputters/__init__.py b/mammoth/inputters/__init__.py similarity index 61% rename from onmt/inputters/__init__.py rename to mammoth/inputters/__init__.py index 8c32e4d3..4d3d1dbe 100644 --- a/onmt/inputters/__init__.py +++ b/mammoth/inputters/__init__.py @@ -1,4 +1,4 @@ -"""The point of this package is to provide a minimal viable product with: +"""The point of this package is to provide: - vocab loading (cf. vocab.py) - token-counts based batch sampler (cf. dataloader.py) - on the fly pad, bos, eos, unk handling (cf. dataset.py) @@ -6,9 +6,9 @@ - multiple parallel corpora, in accordance with TaskDistributor (cf. distributed.py) """ -from onmt.inputters.dataloader import build_dataloader, DynamicDatasetIter -from onmt.inputters.dataset import get_corpus, build_vocab_counts, ParallelCorpus -from onmt.inputters.vocab import get_vocab, DEFAULT_SPECIALS +from mammoth.inputters.dataloader import build_dataloader, DynamicDatasetIter +from mammoth.inputters.dataset import get_corpus, build_vocab_counts, ParallelCorpus +from mammoth.inputters.vocab import get_vocab, DEFAULT_SPECIALS __all__ = [ diff --git a/onmt/inputters/dataloader.py b/mammoth/inputters/dataloader.py similarity index 96% rename from onmt/inputters/dataloader.py rename to mammoth/inputters/dataloader.py index ed707c94..7e26987b 100644 --- a/onmt/inputters/dataloader.py +++ b/mammoth/inputters/dataloader.py @@ -4,8 +4,8 @@ import torch -from onmt.inputters.dataset import get_corpus -from onmt.utils.logging import logger +from mammoth.inputters.dataset import get_corpus +from mammoth.utils.logging import logger def infinite_iterator(iterable): @@ -13,7 +13,7 @@ def infinite_iterator(iterable): def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True, as_iter=True): - """Convert an onmt.inputters.ParallelCorpus into an infinite iterator of batches""" + """Convert an mammoth.inputters.ParallelCorpus into an infinite iterator of batches""" if not cycle: loader = InferenceBatcher(dataset, batch_size) else: @@ -187,7 +187,7 @@ class DynamicDatasetIter(object): batch_size (int): numbers of examples in a batch; batch_size_multiple (int): make batch size multiply of this; data_type (str): input data type, currently only text; - bucket_size (int): accum this number of examples in a dynamic dataset; + pool_size (int): accum this number of examples in a dynamic dataset; skip_empty_level (str): security level when encouter empty line; stride (int): iterate data files with this stride; offset (int): iterate data files with this offset. @@ -209,7 +209,7 @@ def __init__( batch_size, batch_size_multiple, data_type="text", - bucket_size=2048, + pool_size=2048, n_buckets=1024, skip_empty_level='warning', ): @@ -225,7 +225,7 @@ def __init__( self.batch_size = batch_size self.batch_size_multiple = batch_size_multiple self.device = 'cpu' - self.bucket_size = bucket_size + self.pool_size = pool_size self.n_buckets = n_buckets if skip_empty_level not in ['silent', 'warning', 'error']: raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}") @@ -242,7 +242,7 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra return cls( task_queue_manager, opts, - opts.data, + opts.tasks, transforms_cls, vocabs_dict, is_train, @@ -250,7 +250,7 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra batch_size, batch_size_multiple, data_type=opts.data_type, - bucket_size=opts.bucket_size, + pool_size=opts.pool_size, n_buckets=opts.n_buckets, skip_empty_level=opts.skip_empty_level, ) @@ -275,7 +275,7 @@ def _init_datasets(self): # Case 2: we are validation (hence self.is_train := False), we need an iterator # if and only the task defines validation data, i.e. if the key `path_valid_src` # is defined - if self.is_train or self.opts.data[task.corpus_id].get('path_valid_src', None) is not None: + if self.is_train or self.opts.tasks[task.corpus_id].get('path_valid_src', None) is not None: corpus = get_corpus( self.opts, task, src_vocab, tgt_vocab, is_train=self.is_train ).to(device) @@ -285,7 +285,7 @@ def _init_datasets(self): corpus, self.batch_size, self.batch_type, - self.bucket_size, + self.pool_size, n_buckets=self.n_buckets, cycle=self.is_train, as_iter=self.is_train, diff --git a/onmt/inputters/dataset.py b/mammoth/inputters/dataset.py similarity index 95% rename from onmt/inputters/dataset.py rename to mammoth/inputters/dataset.py index a8044dec..6bb2d1a9 100644 --- a/onmt/inputters/dataset.py +++ b/mammoth/inputters/dataset.py @@ -8,10 +8,10 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils.data import IterableDataset -from onmt.constants import DefaultTokens -from onmt.transforms import TransformPipe, get_transforms_cls, make_transforms -from onmt.utils.logging import logger -from onmt.inputters.vocab import Vocab +from mammoth.constants import DefaultTokens +from mammoth.transforms import TransformPipe, get_transforms_cls, make_transforms +from mammoth.utils.logging import logger +from mammoth.inputters.vocab import Vocab @dataclass @@ -102,7 +102,7 @@ def __init__( self.is_train = is_train self.corpus_id = task.corpus_id - # FIXME: most likely redundant with onmt.transforms.tokenize + # FIXME: most likely redundant with mammoth.transforms.tokenize def _tokenize(self, string, side='src'): """Split string, accompanied by a drumroll""" return string.split() @@ -177,7 +177,7 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool = # get transform classes to infer special tokens # FIXME ensure TQM properly initializes transform with global if necessary vocabs = {'src': src_vocab, 'tgt': tgt_vocab} - corpus_opts = opts.data[task.corpus_id] + corpus_opts = opts.tasks[task.corpus_id] transforms_to_apply = corpus_opts.get('transforms', None) transforms_to_apply = transforms_to_apply or opts.get('transforms', None) transforms_to_apply = transforms_to_apply or [] @@ -245,8 +245,8 @@ def build_vocab_counts(opts, corpus_id, transforms, n_sample=3): corpora = { corpus_id: read_examples_from_files( - opts.data[corpus_id]["path_src"], - opts.data[corpus_id]["path_tgt"], + opts.tasks[corpus_id]["path_src"], + opts.tasks[corpus_id]["path_tgt"], # FIXME this is likely not working transforms_fn=TransformPipe(transforms).apply if transforms else lambda x: x, ) diff --git a/onmt/inputters/vocab.py b/mammoth/inputters/vocab.py similarity index 97% rename from onmt/inputters/vocab.py rename to mammoth/inputters/vocab.py index b2ed194b..cc3cbd31 100644 --- a/onmt/inputters/vocab.py +++ b/mammoth/inputters/vocab.py @@ -3,8 +3,8 @@ import itertools import os -from onmt.utils.logging import logger -from onmt.constants import DefaultTokens +from mammoth.utils.logging import logger +from mammoth.constants import DefaultTokens DEFAULT_SPECIALS = (DefaultTokens.BOS, DefaultTokens.EOS, DefaultTokens.UNK, DefaultTokens.PAD) diff --git a/onmt/model_builder.py b/mammoth/model_builder.py similarity index 69% rename from onmt/model_builder.py rename to mammoth/model_builder.py index 69abb0a7..35dfec85 100644 --- a/onmt/model_builder.py +++ b/mammoth/model_builder.py @@ -7,95 +7,92 @@ from torch.nn.init import xavier_uniform_ from collections import defaultdict -# from torchtext.legacy.data import Field -import onmt.modules +import mammoth.modules -from onmt.models.adapters import ( +from mammoth.models.adapters import ( Adapter, EncoderAdapterLayer, DecoderAdapterLayer, ) -from onmt.constants import ModelTask, DefaultTokens -from onmt.decoders.layer_stack_decoder import LayerStackDecoder -from onmt.encoders.layer_stack_encoder import LayerStackEncoder -from onmt.modules import Embeddings -from onmt.modules.embeddings import PluggableEmbeddings -from onmt.modules.util_class import Cast -from onmt.utils.logging import logger -from onmt.utils.misc import use_gpu -from onmt.utils.module_splitter import _combine_ordered_dicts -from onmt.utils.parse import ArgumentParser +from mammoth.constants import ModelTask, DefaultTokens +from mammoth.modules.layer_stack_decoder import LayerStackDecoder +from mammoth.modules.layer_stack_encoder import LayerStackEncoder +from mammoth.modules import Embeddings +from mammoth.modules.embeddings import PluggableEmbeddings +from mammoth.modules.util_class import Cast +from mammoth.utils.logging import logger +from mammoth.utils.misc import use_gpu +from mammoth.utils.module_splitter import _combine_ordered_dicts +from mammoth.utils.parse import ArgumentParser -from onmt.attention_bridge import AttentionBridge +from mammoth.modules.attention_bridge import AttentionBridge -def build_embeddings(opt, vocab, for_encoder=True): +def build_embeddings(opts, vocab, for_encoder=True): """ Args: - opt: the option in current environment. + opts: the option in current environment. vocab: stoi-ish object. for_encoder(bool): build Embeddings for encoder or decoder? """ - emb_dim = opt.src_word_vec_size if for_encoder else opt.tgt_word_vec_size word_padding_idx = vocab.stoi[DefaultTokens.PAD] - opt.word_padding_idx = word_padding_idx + opts.word_padding_idx = word_padding_idx - freeze_word_vecs = opt.freeze_word_vecs_enc if for_encoder else opt.freeze_word_vecs_dec + freeze_word_vecs = opts.freeze_word_vecs_enc if for_encoder else opts.freeze_word_vecs_dec emb = Embeddings( - word_vec_size=emb_dim, - position_encoding=opt.position_encoding, - dropout=opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + word_vec_size=opts.model_dim, + position_encoding=opts.position_encoding, + dropout=opts.dropout[0] if type(opts.dropout) is list else opts.dropout, word_padding_idx=word_padding_idx, word_vocab_size=len(vocab), - sparse=opt.optim == "sparseadam", freeze_word_vecs=freeze_word_vecs, ) return emb -def build_encoder(opt, embeddings, task_queue_manager): +def build_encoder(opts, embeddings, task_queue_manager): """ Various encoder dispatcher function. Args: - opt: the option in current environment. + opts: the option in current environment. embeddings (Embeddings): vocab embeddings for this encoder. """ - assert opt.encoder_type == 'transformer', 'Only Transformer is supported' - return LayerStackEncoder.from_opt(opt, embeddings, task_queue_manager) + assert opts.encoder_type == 'transformer', 'Only Transformer is supported' + return LayerStackEncoder.from_opts(opts, embeddings, task_queue_manager) -def build_decoder(opt, embeddings, task_queue_manager): +def build_decoder(opts, embeddings, task_queue_manager): """ Various decoder dispatcher function. Args: - opt: the option in current environment. + opts: the option in current environment. embeddings (Embeddings): vocab embeddings for this decoder. """ - assert opt.decoder_type == 'transformer', 'Only Transformer is supported' - return LayerStackDecoder.from_opt(opt, embeddings, task_queue_manager) + assert opts.decoder_type == 'transformer', 'Only Transformer is supported' + return LayerStackDecoder.from_opts(opts, embeddings, task_queue_manager) -def load_test_multitask_model(opt, model_path=None): +def load_test_multitask_model(opts, model_path=None): """If a checkpoint ending with ".pt" returns a full model otherwise it builds a bilingual model""" if model_path is None: - model_path = opt.models[0] + model_path = opts.models[0] - opt.lang_pair = opt.lang_pair if opt.lang_pair else f'{opt.src_lang}-{opt.tgt_lang}' + opts.lang_pair = opts.lang_pair if opts.lang_pair else f'{opts.src_lang}-{opts.tgt_lang}' if model_path.endswith('.pt'): - return load_test_model(opt, model_path) + return load_test_model(opts, model_path) else: checkpoint_modules = [ - (f'encoder.embeddings.embeddings_{opt.src_lang}.', f'src_embeddings_{opt.src_lang}'), - (f'decoder.embeddings.embeddings_{opt.tgt_lang}.', f'tgt_embeddings_{opt.tgt_lang}'), - (f'generator.generator_{opt.tgt_lang}.', f'generator_{opt.tgt_lang}'), + (f'encoder.embeddings.embeddings_{opts.src_lang}.', f'src_embeddings_{opts.src_lang}'), + (f'decoder.embeddings.embeddings_{opts.tgt_lang}.', f'tgt_embeddings_{opts.tgt_lang}'), + (f'generator.generator_{opts.tgt_lang}.', f'generator_{opts.tgt_lang}'), ('attention_bridge.', 'attention_bridge'), ] - for layer_stack_idx, layer_stack_opt in enumerate(opt.stack['encoder']): + for layer_stack_idx, layer_stack_opt in enumerate(opts.stack['encoder']): layer_stack_key = layer_stack_opt['id'] checkpoint_modules.append( ( @@ -110,7 +107,7 @@ def load_test_multitask_model(opt, model_path=None): f'encoder_adapter_{layer_stack_idx}_{layer_stack_key}_{adapter_group}_{sub_id}' ) ) - for layer_stack_idx, layer_stack_opt in enumerate(opt.stack['decoder']): + for layer_stack_idx, layer_stack_opt in enumerate(opts.stack['decoder']): layer_stack_key = layer_stack_opt['id'] checkpoint_modules.append( ( @@ -131,8 +128,8 @@ def load_test_multitask_model(opt, model_path=None): (prefix, f'{model_path}_{key}.pt') for (prefix, key) in checkpoint_modules ] - opt.model_frame = model_path + '_frame.pt' - frame = torch.load(opt.model_frame, map_location=lambda storage, loc: storage) + opts.model_frame = model_path + '_frame.pt' + frame = torch.load(opts.model_frame, map_location=lambda storage, loc: storage) checkpoint_state_dicts = { prefix: torch.load(path, map_location=lambda storage, loc: storage) @@ -142,20 +139,20 @@ def load_test_multitask_model(opt, model_path=None): combined_state_dict = _combine_ordered_dicts(checkpoint_state_dicts) vocabs_dict = { - 'src': frame["vocab"].get(('src', opt.src_lang)), - 'tgt': frame["vocab"].get(('tgt', opt.tgt_lang)), + 'src': frame["vocab"].get(('src', opts.src_lang)), + 'tgt': frame["vocab"].get(('tgt', opts.tgt_lang)), } # FIXME # fields["indices"] = Field(use_vocab=False, dtype=torch.long, sequential=False) - model_opt = ArgumentParser.ckpt_model_opts(frame['opt']) + model_opts = ArgumentParser.ckpt_model_opts(frame['opts']) # Avoid functionality on inference - model_opt.update_vocab = False + model_opts.update_vocab = False model = create_bilingual_model( - src_lang=opt.src_lang, - tgt_lang=opt.tgt_lang, - opt_stack=opt.stack, - model_opt=model_opt, + src_lang=opts.src_lang, + tgt_lang=opts.tgt_lang, + opt_stack=opts.stack, + model_opts=model_opts, vocabs_dict=vocabs_dict ) model_params = {name for name, p in model.named_parameters()} @@ -168,24 +165,24 @@ def load_test_multitask_model(opt, model_path=None): if key not in combined_state_dict: print(f'Key missing {key}') model.load_state_dict(combined_state_dict) - device = torch.device("cuda" if use_gpu(opt) else "cpu") + device = torch.device("cuda" if use_gpu(opts) else "cpu") model.to(device) model.eval() - return vocabs_dict, model, model_opt + return vocabs_dict, model, model_opts -def load_test_model(opt, model_path=None): +def load_test_model(opts, model_path=None): if model_path is None: - model_path = opt.models[0] + model_path = opts.models[0] - if len(opt.models) > 1: - model_path_enc = opt.models[0] + if len(opts.models) > 1: + model_path_enc = opts.models[0] checkpoint = torch.load(model_path_enc, map_location=lambda storage, loc: storage) model = checkpoint['whole_model'] - model_path_dec = opt.models[1] + model_path_dec = opts.models[1] model_dec = torch.load(model_path_dec, map_location=lambda storage, loc: storage)['whole_model'] model.decoder = model_dec.decoder model.generator = model_dec.generator @@ -193,17 +190,17 @@ def load_test_model(opt, model_path=None): checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) model = checkpoint['whole_model'] - model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt']) - ArgumentParser.update_model_opts(model_opt) - ArgumentParser.validate_model_opts(model_opt) + model_opts = ArgumentParser.ckpt_model_opts(checkpoint['opts']) + ArgumentParser.update_model_opts(model_opts) + ArgumentParser.validate_model_opts(model_opts) vocabs = checkpoint['vocab'] print("VOCABS") print(vocabs) - if opt.gpu != -1: + if opts.gpu != -1: device = torch.device("cuda") model.to(device) - lang_pair = opt.lang_pair + lang_pair = opts.lang_pair src_lang, tgt_lang = lang_pair.split("-") # FIXME vocabs_dict = {} @@ -213,48 +210,48 @@ def load_test_model(opt, model_path=None): # fields["indices"] = indices # Avoid functionality on inference - model_opt.update_vocab = False + model_opts.update_vocab = False - if opt.fp32: + if opts.fp32: model.float() - elif opt.int8: - if opt.gpu >= 0: + elif opts.int8: + if opts.gpu >= 0: raise ValueError("Dynamic 8-bit quantization is not supported on GPU") torch.quantization.quantize_dynamic(model, inplace=True) model.eval() model.generator.eval() - return vocabs_dict, model, model_opt + return vocabs_dict, model, model_opts def create_bilingual_model( - src_lang, tgt_lang, opt_stack, model_opt, vocabs_dict + src_lang, tgt_lang, opt_stack, model_opts, vocabs_dict ): """For translation.""" generators_md = nn.ModuleDict() - src_emb = build_src_emb(model_opt, vocabs_dict['src']) - tgt_emb = build_tgt_emb(model_opt, vocabs_dict['tgt']) + src_emb = build_src_emb(model_opts, vocabs_dict['src']) + tgt_emb = build_tgt_emb(model_opts, vocabs_dict['tgt']) pluggable_src_emb = PluggableEmbeddings({src_lang: src_emb}) pluggable_tgt_emb = PluggableEmbeddings({tgt_lang: tgt_emb}) pluggable_src_emb.activate(src_lang) pluggable_tgt_emb.activate(tgt_lang) - encoder = LayerStackEncoder.from_trans_opt(model_opt, pluggable_src_emb, opt_stack) - decoder = LayerStackDecoder.from_trans_opt(model_opt, pluggable_tgt_emb, opt_stack) - generator = build_generator(model_opt, len(vocabs_dict['tgt']), tgt_emb) + encoder = LayerStackEncoder.from_trans_opt(model_opts, pluggable_src_emb, opt_stack) + decoder = LayerStackDecoder.from_trans_opt(model_opts, pluggable_tgt_emb, opt_stack) + generator = build_generator(model_opts, len(vocabs_dict['tgt']), tgt_emb) generators_md.add_module(f'generator_{tgt_lang}', generator) - attention_bridge = AttentionBridge.from_opt(model_opt) + attention_bridge = AttentionBridge.from_opts(model_opts) - nmt_model = onmt.models.NMTModel( + nmt_model = mammoth.models.NMTModel( encoder=encoder, decoder=decoder, attention_bridge=attention_bridge ) - if uses_adapters(model_opt): + if uses_adapters(model_opts): logger.info('Creating adapters...') - create_bilingual_adapters(nmt_model, model_opt, src_lang, tgt_lang, opt_stack) + create_bilingual_adapters(nmt_model, model_opts, src_lang, tgt_lang, opt_stack) else: logger.info('Does not use adapters...') print('built model:') @@ -264,18 +261,18 @@ def create_bilingual_model( return nmt_model -def build_src_emb(model_opt, src_vocab): +def build_src_emb(model_opts, src_vocab): # Build embeddings. - if model_opt.model_type == "text": - src_emb = build_embeddings(model_opt, src_vocab) + if model_opts.model_type == "text": + src_emb = build_embeddings(model_opts, src_vocab) else: src_emb = None return src_emb -def build_tgt_emb(model_opt, tgt_vocab): +def build_tgt_emb(model_opts, tgt_vocab): # Build embeddings. - tgt_emb = build_embeddings(model_opt, tgt_vocab, for_encoder=False) + tgt_emb = build_embeddings(model_opts, tgt_vocab, for_encoder=False) # if share_embeddings: # tgt_emb.word_lut.weight = src_emb.word_lut.weight @@ -284,15 +281,15 @@ def build_tgt_emb(model_opt, tgt_vocab): def build_task_specific_model( - model_opt, + model_opts, vocabs_dict, device, task_queue_manager, checkpoint, ): logger.info(f'TaskQueueManager: {task_queue_manager}') - if not model_opt.model_task == ModelTask.SEQ2SEQ: - raise ValueError(f"Only ModelTask.SEQ2SEQ works - {model_opt.model_task} task") + if not model_opts.model_task == ModelTask.SEQ2SEQ: + raise ValueError(f"Only ModelTask.SEQ2SEQ works - {model_opts.model_task} task") src_embs = dict() tgt_embs = dict() @@ -301,42 +298,42 @@ def build_task_specific_model( # FIXME: it's getting late and I just want this to compile for side, lang, _, vocab in task_queue_manager.get_vocabs(side='src', vocabs_dict=vocabs_dict): - src_emb = build_src_emb(model_opt, vocab) + src_emb = build_src_emb(model_opts, vocab) src_embs[lang] = src_emb pluggable_src_emb = PluggableEmbeddings(src_embs) - encoder = build_only_enc(model_opt, pluggable_src_emb, task_queue_manager) + encoder = build_only_enc(model_opts, pluggable_src_emb, task_queue_manager) for side, lang, _, vocab in task_queue_manager.get_vocabs(side='tgt', vocabs_dict=vocabs_dict): - tgt_emb = build_tgt_emb(model_opt, vocab) + tgt_emb = build_tgt_emb(model_opts, vocab) tgt_embs[lang] = tgt_emb - generator = build_generator(model_opt, len(vocab), tgt_emb) + generator = build_generator(model_opts, len(vocab), tgt_emb) generators_md.add_module(f'generator_{lang}', generator) pluggable_tgt_emb = PluggableEmbeddings(tgt_embs) - decoder = build_only_dec(model_opt, pluggable_tgt_emb, task_queue_manager) + decoder = build_only_dec(model_opts, pluggable_tgt_emb, task_queue_manager) # TODO: implement hierarchical approach to layer sharing - attention_bridge = AttentionBridge.from_opt(model_opt) + attention_bridge = AttentionBridge.from_opts(model_opts) - if model_opt.param_init != 0.0: + if model_opts.param_init != 0.0: for p in attention_bridge.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - if model_opt.param_init_glorot: + p.data.uniform_(-model_opts.param_init, model_opts.param_init) + if model_opts.param_init_glorot: for p in attention_bridge.parameters(): if p.dim() > 1: xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) - if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam': + if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam': attention_bridge.half() - nmt_model = onmt.models.NMTModel( + nmt_model = mammoth.models.NMTModel( encoder=encoder, decoder=decoder, attention_bridge=attention_bridge ) - if uses_adapters(model_opt): + if uses_adapters(model_opts): logger.info('Creating adapters...') - create_all_adapters(nmt_model, model_opt, task_queue_manager) + create_all_adapters(nmt_model, model_opts, task_queue_manager) print('built model:') print(nmt_model) @@ -365,57 +362,54 @@ def has_grad_hook(module, input, output) -> None: return nmt_model, generators_md -def build_only_enc(model_opt, src_emb, task_queue_manager): +def build_only_enc(model_opts, src_emb, task_queue_manager): """Truly only builds encoder: no embeddings""" - encoder = build_encoder(model_opt, src_emb, task_queue_manager) - if model_opt.param_init != 0.0: + encoder = build_encoder(model_opts, src_emb, task_queue_manager) + if model_opts.param_init != 0.0: for p in encoder.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - if model_opt.param_init_glorot: + p.data.uniform_(-model_opts.param_init, model_opts.param_init) + if model_opts.param_init_glorot: for p in encoder.parameters(): if p.dim() > 1: xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) - if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam': + if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam': encoder.half() return encoder -def build_only_dec(model_opt, tgt_emb, task_queue_manager): - decoder = build_decoder(model_opt, tgt_emb, task_queue_manager) +def build_only_dec(model_opts, tgt_emb, task_queue_manager): + decoder = build_decoder(model_opts, tgt_emb, task_queue_manager) - if model_opt.param_init != 0.0: + if model_opts.param_init != 0.0: for p in decoder.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - if model_opt.param_init_glorot: + p.data.uniform_(-model_opts.param_init, model_opts.param_init) + if model_opts.param_init_glorot: for p in decoder.parameters(): if p.dim() > 1: xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) - if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam': + if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam': decoder.half() return decoder -def build_generator(model_opt, n_tgts, tgt_emb): +def build_generator(model_opts, n_tgts, tgt_emb): # Build Generator. - assert not model_opt.copy_attn, 'copy_attn not supported' - if model_opt.generator_function == "sparsemax": - gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) - else: - gen_func = nn.LogSoftmax(dim=-1) + assert not model_opts.copy_attn, 'copy_attn not supported' + gen_func = nn.LogSoftmax(dim=-1) generator = nn.Sequential( - nn.Linear(model_opt.dec_rnn_size, n_tgts), Cast(torch.float32), gen_func + nn.Linear(model_opts.model_dim, n_tgts), Cast(torch.float32), gen_func ) - if model_opt.share_decoder_embeddings: + if model_opts.share_decoder_embeddings: generator[0].weight = tgt_emb.word_lut.weight - if model_opt.param_init != 0.0: + if model_opts.param_init != 0.0: for p in generator.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - if model_opt.param_init_glorot: + p.data.uniform_(-model_opts.param_init, model_opts.param_init) + if model_opts.param_init_glorot: for p in generator.parameters(): if p.dim() > 1: xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) @@ -455,7 +449,7 @@ def build_generator(model_opt, n_tgts, tgt_emb): def build_base_model_langspec( - model_opt, + model_opts, vocabs_dict, gpu, task_queue_manager, @@ -464,10 +458,10 @@ def build_base_model_langspec( """Build a model from opts. Args: - model_opt: the option loaded from checkpoint. It's important that + model_opts: the option loaded from checkpoint. It's important that the opts have been updated and validated. See - :class:`onmt.utils.parse.ArgumentParser`. - vocabs_dict (dict[str, onmt.inputters.Vocab]): + :class:`mammoth.utils.parse.ArgumentParser`. + vocabs_dict (dict[str, mammoth.inputters.Vocab]): `Vocab` objects for the model. gpu (bool): whether to use gpu. checkpoint: the model gnerated by train phase, or a resumed snapshot @@ -480,9 +474,9 @@ def build_base_model_langspec( # for back compat when attention_dropout was not defined try: - model_opt.attention_dropout + model_opts.attention_dropout except AttributeError: - model_opt.attention_dropout = model_opt.dropout + model_opts.attention_dropout = model_opts.dropout # Build Model logger.info("MODEL BUILDER") @@ -492,7 +486,7 @@ def build_base_model_langspec( device = torch.device("cpu") logger.info(device) model, generators_md = build_task_specific_model( - model_opt=model_opt, + model_opts=model_opts, vocabs_dict=vocabs_dict, device=device, task_queue_manager=task_queue_manager, @@ -505,11 +499,11 @@ def build_base_model_langspec( return model, generators_md -def uses_adapters(opt): - return 'adapters' in opt and opt.adapters +def uses_adapters(opts): + return 'adapters' in opts and opts.adapters -def create_all_adapters(model, opt, task_queue_manager): +def create_all_adapters(model, opts, task_queue_manager): my_enc_adapter_ids = set() my_dec_adapter_ids = set() adapter_to_encoder_ids = defaultdict(set) @@ -525,7 +519,7 @@ def create_all_adapters(model, opt, task_queue_manager): adapter_to_decoder_ids[adapter_id].add(tuple(task.decoder_id)) _create_adapters( model, - opt, + opts, my_enc_adapter_ids, adapter_to_encoder_ids, my_dec_adapter_ids, @@ -533,7 +527,7 @@ def create_all_adapters(model, opt, task_queue_manager): ) -def create_bilingual_adapters(model, opt, src_lang, tgt_lang, opt_stack): +def create_bilingual_adapters(model, opts, src_lang, tgt_lang, opt_stack): my_enc_adapter_ids = [] my_dec_adapter_ids = [] adapter_to_encoder_ids = {} @@ -556,7 +550,7 @@ def create_bilingual_adapters(model, opt, src_lang, tgt_lang, opt_stack): _create_adapters( model, - opt, + opts, my_enc_adapter_ids, adapter_to_encoder_ids, my_dec_adapter_ids, @@ -566,7 +560,7 @@ def create_bilingual_adapters(model, opt, src_lang, tgt_lang, opt_stack): def _create_adapters( model, - opt, + opts, my_enc_adapter_ids, adapter_to_encoder_ids, my_dec_adapter_ids, @@ -574,14 +568,14 @@ def _create_adapters( ): my_enc_adapter_ids = [tuple(item) for item in my_enc_adapter_ids] my_dec_adapter_ids = [tuple(item) for item in my_dec_adapter_ids] - for adapter_group, adapter_opts in opt.adapters['encoder'].items(): + for adapter_group, adapter_opts in opts.adapters['encoder'].items(): layer_stack_index = adapter_opts['layer_stack_index'] for sub_id in adapter_opts['ids']: adapter_id_long = (layer_stack_index, adapter_group, sub_id) if adapter_id_long not in my_enc_adapter_ids: continue adapter = Adapter(adapter_group, sub_id) - input_dim = opt.rnn_size + input_dim = opts.model_dim hidden_dim = adapter_opts['hidden_size'] # all stacks to which this adapter should be added @@ -602,14 +596,14 @@ def _create_adapters( layer_stack_index=layer_stack_index, module_ids=adapted_stacks, ) - for adapter_group, adapter_opts in opt.adapters['decoder'].items(): + for adapter_group, adapter_opts in opts.adapters['decoder'].items(): layer_stack_index = adapter_opts['layer_stack_index'] for sub_id in adapter_opts['ids']: adapter_id_long = (layer_stack_index, adapter_group, sub_id) if adapter_id_long not in my_dec_adapter_ids: continue adapter = Adapter(adapter_group, sub_id) - input_dim = opt.rnn_size + input_dim = opts.model_dim hidden_dim = adapter_opts['hidden_size'] adapted_stacks = set( @@ -631,12 +625,12 @@ def _create_adapters( ) -def build_model(model_opt, opt, vocabs_dict, task_queue_manager, checkpoint): +def build_model(model_opts, opts, vocabs_dict, task_queue_manager, checkpoint): logger.info('Building model...') model, generators_md = build_base_model_langspec( - model_opt=model_opt, + model_opts=model_opts, vocabs_dict=vocabs_dict, - gpu=use_gpu(opt), + gpu=use_gpu(opts), task_queue_manager=task_queue_manager, checkpoint=checkpoint, ) diff --git a/mammoth/models/__init__.py b/mammoth/models/__init__.py new file mode 100644 index 00000000..30263cd6 --- /dev/null +++ b/mammoth/models/__init__.py @@ -0,0 +1,5 @@ +"""Module defining models.""" +from mammoth.models.model_saver import build_model_saver, ModelSaver +from mammoth.models.model import NMTModel + +__all__ = ["build_model_saver", "ModelSaver", "NMTModel"] diff --git a/onmt/models/adapters.py b/mammoth/models/adapters.py similarity index 98% rename from onmt/models/adapters.py rename to mammoth/models/adapters.py index 26311237..c2e90c5d 100644 --- a/onmt/models/adapters.py +++ b/mammoth/models/adapters.py @@ -8,9 +8,9 @@ from abc import ABC from collections import defaultdict -from onmt.encoders import TransformerEncoder -from onmt.decoders import TransformerDecoder -from onmt.rmsnorm_torch import RMSNorm +from mammoth.modules import TransformerEncoder +from mammoth.modules import TransformerDecoder +from mammoth.rmsnorm_torch import RMSNorm class AdapterLayer(ABC, nn.Module): diff --git a/onmt/models/model.py b/mammoth/models/model.py similarity index 56% rename from onmt/models/model.py rename to mammoth/models/model.py index 36097e11..27bdf44a 100644 --- a/onmt/models/model.py +++ b/mammoth/models/model.py @@ -48,8 +48,8 @@ class NMTModel(BaseModel): Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder + decoder model. Args: - encoder (onmt.encoders.EncoderBase): an encoder object - decoder (onmt.decoders.DecoderBase): a decoder object + encoder (mammoth.encoders.EncoderBase): an encoder object + decoder (mammoth.decoders.DecoderBase): a decoder object """ def __init__(self, encoder, decoder, attention_bridge): @@ -102,68 +102,3 @@ def count_parameters(self, log=print): log('decoder: {}'.format(dec)) log('* number of parameters: {}'.format(enc + dec)) return enc, dec - - -class LanguageModel(BaseModel): - """ - Core trainable object in OpenNMT. Implements a trainable interface - for a simple, generic decoder only model. - Currently TransformerLMDecoder is the only LM decoder implemented - Args: - decoder (onmt.decoders.TransformerLMDecoder): a transformer decoder - """ - - def __init__(self, encoder=None, decoder=None): - super(LanguageModel, self).__init__(encoder, decoder) - if encoder is not None: - raise ValueError("LanguageModel should not be used with an encoder") - self.decoder = decoder - - def forward(self, src, tgt, lengths, bptt=False, with_align=False): - """Forward propagate a `src` and `tgt` pair for training. - Possible initialized with a beginning decoder state. - Args: - src (Tensor): A source sequence passed to decoder. - typically for inputs this will be a padded `LongTensor` - of size ``(len, batch, features)``. However, may be an - image or other generic input depending on decoder. - tgt (LongTensor): A target sequence passed to decoder. - Size ``(tgt_len, batch, features)``. - lengths(LongTensor): The src lengths, pre-padding ``(batch,)``. - bptt (Boolean): A flag indicating if truncated bptt is set. - If reset then init_state - with_align (Boolean): A flag indicating whether output alignment, - Only valid for transformer decoder. - Returns: - (FloatTensor, dict[str, FloatTensor]): - * decoder output ``(tgt_len, batch, hidden)`` - * dictionary attention dists of ``(tgt_len, batch, src_len)`` - """ - - if not bptt: - self.decoder.init_state() - dec_out, attns = self.decoder(src, memory_bank=None, memory_lengths=lengths, with_align=with_align) - return dec_out, attns - - def update_dropout(self, dropout): - self.decoder.update_dropout(dropout) - - def count_parameters(self, log=print): - """Count number of parameters in model (& print with `log` callback). - Returns: - (int, int): - * encoder side parameter count - * decoder side parameter count - """ - - enc, dec = 0, 0 - for name, param in self.named_parameters(): - if "decoder" in name: - dec += param.nelement() - - if callable(log): - # No encoder in LM, seq2seq count formatting kept - log("total encoder parameters: {}".format(enc)) - log("total decoder parameters: {}".format(dec)) - log("* number of parameters: {}".format(enc + dec)) - return enc, dec diff --git a/onmt/models/model_saver.py b/mammoth/models/model_saver.py similarity index 91% rename from onmt/models/model_saver.py rename to mammoth/models/model_saver.py index 44ea45f4..7ece5078 100644 --- a/onmt/models/model_saver.py +++ b/mammoth/models/model_saver.py @@ -1,20 +1,20 @@ import os from collections import deque -from onmt.utils.logging import logger +from mammoth.utils.logging import logger import torch import torch.nn as nn -from onmt.utils.module_splitter import explode_model +from mammoth.utils.module_splitter import explode_model -def build_model_saver(model_opt, opt, model, vocabs_dict, optim, device_context): +def build_model_saver(model_opts, opts, model, vocabs_dict, optim, device_context): # _check_save_model_path - save_model_path = os.path.abspath(opt.save_model) + save_model_path = os.path.abspath(opts.save_model) os.makedirs(os.path.dirname(save_model_path), exist_ok=True) model_saver = ModelSaver( - opt.save_model, model, model_opt, vocabs_dict, optim, opt.keep_checkpoint, device_context, opt.save_all_gpus + opts.save_model, model, model_opts, vocabs_dict, optim, opts.keep_checkpoint, device_context, opts.save_all_gpus ) return model_saver @@ -40,7 +40,7 @@ def __init__( self, base_path, model, - model_opt, + model_opts, vocabs_dict, optim, keep_checkpoint=-1, @@ -49,7 +49,7 @@ def __init__( ): self.base_path = base_path self.model = model - self.model_opt = model_opt + self.model_opts = model_opts self.vocabs_dict = vocabs_dict self.optim = optim self.last_saved_step = None @@ -129,7 +129,7 @@ def _save(self, step, model, device_context): "model": model_state_dict, # 'generator': generator_state_dict, "vocab": self.vocabs_dict, - "opt": self.model_opt, + "opts": self.model_opts, "optim": {k: v.state_dict() for k, v in self.optim._optimizer.optimizers.items()}, "whole_model": self.model, } @@ -159,7 +159,7 @@ def _save(self, step, model, device_context): tmp_checkpoint_paths.append(checkpoint_path) if device_context.is_master(): - # TODO: not sure how to deal with model_state_dict, fields, model_opt and optim.state_dict() in a multi-gpu + # TODO: not sure how to deal with model_state_dict, fields, model_opts and optim.state_dict() in a multi-gpu # setting. Is it OK to save only from master? # model frame diff --git a/mammoth/modules/__init__.py b/mammoth/modules/__init__.py new file mode 100644 index 00000000..975b2ef6 --- /dev/null +++ b/mammoth/modules/__init__.py @@ -0,0 +1,41 @@ +"""Components libary""" +from mammoth.modules.util_class import Elementwise +from mammoth.modules.multi_headed_attn import MultiHeadedAttention +from mammoth.modules.embeddings import Embeddings, PositionalEncoding +# from mammoth.modules.weight_norm import WeightNormConv2d +from mammoth.modules.average_attn import AverageAttention +from mammoth.modules.attention_bridge import AttentionBridge + +from mammoth.modules.encoder import EncoderBase +from mammoth.modules.transformer_encoder import TransformerEncoder +from mammoth.modules.mean_encoder import MeanEncoder + +from mammoth.modules.decoder import DecoderBase +from mammoth.modules.transformer_decoder import TransformerDecoder + + +str2enc = { + "transformer": TransformerEncoder, + "mean": MeanEncoder, +} + +str2dec = { + "transformer": TransformerDecoder, +} + + +__all__ = [ + "DecoderBase", + "TransformerDecoder", + "str2dec", + "EncoderBase", + "TransformerEncoder", + "MeanEncoder", + "str2enc", + "Elementwise", + "MultiHeadedAttention", + "Embeddings", + "PositionalEncoding", + "AverageAttention", + "AttentionBridge", +] diff --git a/onmt/attention_bridge.py b/mammoth/modules/attention_bridge.py similarity index 88% rename from onmt/attention_bridge.py rename to mammoth/modules/attention_bridge.py index 481c6f8e..d6fa0592 100644 --- a/onmt/attention_bridge.py +++ b/mammoth/modules/attention_bridge.py @@ -4,10 +4,10 @@ import torch import torch.nn as nn -from onmt.rmsnorm_torch import RMSNorm -from onmt.encoders.transformer import TransformerEncoderLayer +from mammoth.rmsnorm_torch import RMSNorm +from mammoth.modules.transformer_encoder import TransformerEncoderLayer -from onmt.modules.multi_headed_attn import MultiHeadedAttention +from mammoth.modules.multi_headed_attn import MultiHeadedAttention class BaseAttentionBridgeLayer(nn.Module): @@ -73,15 +73,15 @@ def __init__( self.self_ff_norm = AttentionBridgeNorm(latent_size, norm_type) @classmethod - def from_opt(cls, opt): + def from_opts(cls, opts): return cls( - opt.rnn_size, - opt.hidden_ab_size, - opt.ab_fixed_length, - opt.heads, - opt.attention_dropout[0], - opt.max_relative_positions, - opt.ab_layer_norm, + opts.model_dim, + opts.hidden_ab_size, + opts.ab_fixed_length, + opts.heads, + opts.attention_dropout[0], + opts.max_relative_positions, + opts.ab_layer_norm, ) @property @@ -133,7 +133,7 @@ def __init__( attention_heads, hidden_ab_size, model_type, - dec_rnn_size, + model_dim, ab_layer_norm=None, ): """Attention Heads Layer:""" @@ -144,7 +144,7 @@ def __init__( self.dd = u self.model_type = model_type if self.model_type != "text": - d = dec_rnn_size + d = model_dim self.ws1 = nn.Linear(d, u, bias=True) self.ws2 = nn.Linear(u, r, bias=True) self.relu = nn.ReLU() @@ -154,15 +154,15 @@ def __init__( self.norm = AttentionBridgeNorm(d, ab_layer_norm) @classmethod - def from_opt(cls, opt): + def from_opts(cls, opts): """Alternate constructor.""" return cls( - opt.rnn_size, - opt.ab_fixed_length, - opt.hidden_ab_size, - opt.model_type, - opt.dec_rnn_size, - opt.ab_layer_norm, + opts.model_dim, + opts.ab_fixed_length, + opts.hidden_ab_size, + opts.model_type, + opts.model_dim, + opts.ab_layer_norm, ) def forward(self, intermediate_output, encoder_output, mask=None): @@ -244,12 +244,12 @@ def forward(self, intermediate_output, encoder_output, mask=None): return attention_weights, output @classmethod - def from_opt(cls, opt): + def from_opts(cls, opts): return cls( - opt.enc_rnn_size, - opt.hidden_ab_size, - opt.ab_fixed_length, - opt.ab_layer_norm, + opts.model_dim, + opts.hidden_ab_size, + opts.ab_fixed_length, + opts.ab_layer_norm, ) @@ -276,15 +276,15 @@ def forward(self, intermediate_output, encoder_output, mask=None): return None, outp @classmethod - def from_opt(cls, opt): + def from_opts(cls, opts): return cls( - opt.enc_rnn_size, - opt.heads, - opt.hidden_ab_size, # d_ff + opts.model_dim, + opts.heads, + opts.hidden_ab_size, # d_ff # TODO: that list indexing things seems suspicious to me... - opt.dropout[0], - opt.attention_dropout[0], - max_relative_positions=opt.max_relative_positions, + opts.dropout[0], + opts.attention_dropout[0], + max_relative_positions=opts.max_relative_positions, ) @@ -313,11 +313,11 @@ def forward(self, intermediate_output, encoder_output, mask=None): return None, self.module(intermediate_output) @classmethod - def from_opt(cls, opt): + def from_opts(cls, opts): return cls( - opt.enc_rnn_size, - opt.hidden_ab_size, - opt.ab_layer_norm, + opts.model_dim, + opts.hidden_ab_size, + opts.ab_layer_norm, ) @@ -333,7 +333,7 @@ def __init__(self, layers): self.is_fixed_length = any(x.is_fixed_length for x in layers) @classmethod - def from_opt(cls, opt): + def from_opts(cls, opts): """Alternate constructor.""" # convert opts specifications to architectures layer_type_to_cls = { @@ -344,16 +344,16 @@ def from_opt(cls, opt): 'feedforward': FeedForwardAttentionBridgeLayer, } - # preconstruct layers using .from_opt(...) - layers = [layer_type_to_cls[layer_type].from_opt(opt) for layer_type in opt.ab_layers] + # preconstruct layers using .from_opts(...) + layers = [layer_type_to_cls[layer_type].from_opts(opts) for layer_type in opts.ab_layers] # FIXME: locking-in edge case behavior - if any(layer == 'perceiver' for layer in opt.ab_layers): - first_perceiver_index = next(idx for idx, layer in enumerate(opt.ab_layers) if layer == 'perceiver') + if any(layer == 'perceiver' for layer in opts.ab_layers): + first_perceiver_index = next(idx for idx, layer in enumerate(opts.ab_layers) if layer == 'perceiver') if first_perceiver_index != 0: assert any(layer.is_fixed_length for layer in layers[:first_perceiver_index]), \ 'Unsupported bridge configuration: at least one layer must be fixed-size before perceiver' - if not all(layer == 'perceiver' for layer in opt.ab_layers): + if not all(layer == 'perceiver' for layer in opts.ab_layers): warnings.warn('Architecture-mixing not fully supported with perceiver.') # FIXME: deleting unused params manually for perceiver_layer in layers[1:]: diff --git a/onmt/modules/average_attn.py b/mammoth/modules/average_attn.py similarity index 97% rename from onmt/modules/average_attn.py rename to mammoth/modules/average_attn.py index 04fec83f..24c8eb23 100644 --- a/onmt/modules/average_attn.py +++ b/mammoth/modules/average_attn.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn -from onmt.modules.position_ffn import PositionwiseFeedForward -from onmt.modules.position_ffn import ActivationFunction +from mammoth.modules.position_ffn import PositionwiseFeedForward +from mammoth.modules.position_ffn import ActivationFunction class AverageAttention(nn.Module): diff --git a/mammoth/modules/decoder.py b/mammoth/modules/decoder.py new file mode 100644 index 00000000..e0e707f5 --- /dev/null +++ b/mammoth/modules/decoder.py @@ -0,0 +1,22 @@ +import torch.nn as nn + + +class DecoderBase(nn.Module): + """Abstract class for decoders. + + Args: + attentional (bool): The decoder returns non-empty attention. + """ + + def __init__(self, attentional=True): + super(DecoderBase, self).__init__() + self.attentional = attentional + + @classmethod + def from_opts(cls, opts, embeddings): + """Alternate constructor. + + Subclasses should override this method. + """ + + raise NotImplementedError diff --git a/onmt/decoders/ensemble.py b/mammoth/modules/decoder_ensemble.py similarity index 89% rename from onmt/decoders/ensemble.py rename to mammoth/modules/decoder_ensemble.py index c1b2c3b1..08248f1d 100644 --- a/onmt/decoders/ensemble.py +++ b/mammoth/modules/decoder_ensemble.py @@ -9,10 +9,10 @@ import torch import torch.nn as nn -from onmt.encoders.encoder import EncoderBase -from onmt.decoders.decoder import DecoderBase -from onmt.models import NMTModel -import onmt.model_builder +from mammoth.modules.encoder import EncoderBase +from mammoth.modules.decoder import DecoderBase +from mammoth.models import NMTModel +import mammoth.model_builder class EnsembleDecoderOutput(object): @@ -23,7 +23,7 @@ def __init__(self, model_dec_outs): def squeeze(self, dim=None): """Delegate squeeze to avoid modifying - :func:`onmt.translate.translator.Translator.translate_batch()` + :func:`mammoth.translate.translator.Translator.translate_batch()` """ return EnsembleDecoderOutput([x.squeeze(dim) for x in self.model_dec_outs]) @@ -53,7 +53,7 @@ def __init__(self, model_decoders): self.model_decoders = model_decoders def forward(self, tgt, memory_bank, memory_lengths=None, step=None, **kwargs): - """See :func:`onmt.decoders.decoder.DecoderBase.forward()`.""" + """See :func:`mammoth.decoders.decoder.DecoderBase.forward()`.""" # Memory_lengths is a single tensor shared between all models. # This assumption will not hold if Translator is modified # to calculate memory_lengths as something other than the length @@ -120,13 +120,13 @@ def __init__(self, models, raw_probs=False): self.models = nn.ModuleList(models) -def load_test_model(opt): +def load_test_model(opts): """Read in multiple models for ensemble.""" shared_vocabs = None shared_model_opt = None models = [] - for model_path in opt.models: - vocabs, model, model_opt = onmt.model_builder.load_test_multitask_model(opt, model_path=model_path) + for model_path in opts.models: + vocabs, model, model_opts = mammoth.model_builder.load_test_multitask_model(opts, model_path=model_path) if shared_vocabs is None: shared_vocabs = vocabs else: @@ -137,6 +137,6 @@ def load_test_model(opt): # assert vocab.stoi == sh_vocab.stoi, "Ensemble models must use the same preprocessed data" models.append(model) if shared_model_opt is None: - shared_model_opt = model_opt - ensemble_model = EnsembleModel(models, opt.avg_raw_probs) + shared_model_opt = model_opts + ensemble_model = EnsembleModel(models, opts.avg_raw_probs) return shared_vocabs, ensemble_model, shared_model_opt diff --git a/onmt/modules/embeddings.py b/mammoth/modules/embeddings.py similarity index 89% rename from onmt/modules/embeddings.py rename to mammoth/modules/embeddings.py index ce8f7732..dbd16ab1 100644 --- a/onmt/modules/embeddings.py +++ b/mammoth/modules/embeddings.py @@ -5,11 +5,10 @@ import torch import torch.nn as nn -from onmt.modules.util_class import Elementwise -# from onmt.utils.logging import logger +from mammoth.modules.util_class import Elementwise +# from mammoth.utils.logging import logger # import bitsandbytes as bnb -# from onmt.modules.stable_embeddings import StableEmbedding class SequenceTooLongError(Exception): @@ -66,7 +65,7 @@ def forward(self, emb, step=None): class Embeddings(nn.Module): """Words embeddings for encoder/decoder. - Additionally includes ability to add sparse input features + Additionally includes ability to add input features based on "Linguistic Input Features Improve Neural Machine Translation" :cite:`sennrich2016linguistic`. @@ -92,7 +91,7 @@ class Embeddings(nn.Module): word_vocab_size (int): size of dictionary of embeddings for words. feat_vocab_sizes (List[int], optional): list of size of dictionary of embeddings for each feature. - position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding` + position_encoding (bool): see :class:`~mammoth.modules.PositionalEncoding` feat_merge (string): merge action for the features embeddings: concat, sum or mlp. feat_vec_exponent (float): when using `-feat_merge concat`, feature @@ -116,7 +115,6 @@ def __init__( feat_padding_idx=[], feat_vocab_sizes=[], dropout=0, - sparse=False, freeze_word_vecs=False, ): self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent, feat_vec_size, feat_padding_idx) @@ -147,7 +145,7 @@ def __init__( # The embedding matrix look-up tables. The first look-up table # is for words. Subsequent ones are for features, if any exist. emb_params = zip(vocab_sizes, emb_dims, pad_indices) - embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse) for vocab, dim, pad in emb_params] + embeddings = [nn.Embedding(vocab, dim, padding_idx=pad) for vocab, dim, pad in emb_params] emb_luts = Elementwise(feat_merge, embeddings) # The final output size of word + feature vectors. This can vary @@ -340,12 +338,12 @@ def convert_to_torch_tensor(word_to_float_list_dict, vocab): return tensor # FIXME: seems it got nuked during the great refactoring of data -# def prepare_pretrained_embeddings(opt, fields): -# if all([opt.both_embeddings is None, opt.src_embeddings is None, opt.tgt_embeddings is None]): +# def prepare_pretrained_embeddings(opts, fields): +# if all([opts.both_embeddings is None, opts.src_embeddings is None, opts.tgt_embeddings is None]): # return # # assert ( -# opt.save_data +# opts.save_data # ), "-save_data is required when using \ # pretrained embeddings." # @@ -358,42 +356,42 @@ def convert_to_torch_tensor(word_to_float_list_dict, vocab): # vocs.append(vocab) # enc_vocab, dec_vocab = vocs # -# skip_lines = 1 if opt.embeddings_type == "word2vec" else 0 -# if opt.both_embeddings is not None: +# skip_lines = 1 if opts.embeddings_type == "word2vec" else 0 +# if opts.both_embeddings is not None: # set_of_src_and_tgt_vocab = set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys()) -# logger.info("Reading encoder and decoder embeddings from {}".format(opt.both_embeddings)) -# src_vectors, total_vec_count = read_embeddings(opt.both_embeddings, skip_lines, set_of_src_and_tgt_vocab) +# logger.info("Reading encoder and decoder embeddings from {}".format(opts.both_embeddings)) +# src_vectors, total_vec_count = read_embeddings(opts.both_embeddings, skip_lines, set_of_src_and_tgt_vocab) # tgt_vectors = src_vectors # logger.info("\tFound {} total vectors in file".format(total_vec_count)) # else: -# if opt.src_embeddings is not None: -# logger.info("Reading encoder embeddings from {}".format(opt.src_embeddings)) -# src_vectors, total_vec_count = read_embeddings(opt.src_embeddings, skip_lines, filter_set=enc_vocab.stoi) +# if opts.src_embeddings is not None: +# logger.info("Reading encoder embeddings from {}".format(opts.src_embeddings)) +# src_vectors, total_vec_count = read_embeddings(opts.src_embeddings, skip_lines, filter_set=enc_vocab.stoi) # logger.info("\tFound {} total vectors in file.".format(total_vec_count)) # else: # src_vectors = None -# if opt.tgt_embeddings is not None: -# logger.info("Reading decoder embeddings from {}".format(opt.tgt_embeddings)) -# tgt_vectors, total_vec_count = read_embeddings(opt.tgt_embeddings, skip_lines, filter_set=dec_vocab.stoi) +# if opts.tgt_embeddings is not None: +# logger.info("Reading decoder embeddings from {}".format(opts.tgt_embeddings)) +# tgt_vectors, total_vec_count = read_embeddings(opts.tgt_embeddings, skip_lines, filter_set=dec_vocab.stoi) # logger.info("\tFound {} total vectors in file".format(total_vec_count)) # else: # tgt_vectors = None # logger.info("After filtering to vectors in vocab:") -# if opt.src_embeddings is not None or opt.both_embeddings is not None: +# if opts.src_embeddings is not None or opts.both_embeddings is not None: # logger.info("\t* enc: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(enc_vocab, src_vectors)) -# if opt.tgt_embeddings is not None or opt.both_embeddings is not None: +# if opts.tgt_embeddings is not None or opts.both_embeddings is not None: # logger.info("\t* dec: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(dec_vocab, tgt_vectors)) # # # Write to file -# enc_output_file = opt.save_data + ".enc_embeddings.pt" -# dec_output_file = opt.save_data + ".dec_embeddings.pt" -# if opt.src_embeddings is not None or opt.both_embeddings is not None: +# enc_output_file = opts.save_data + ".enc_embeddings.pt" +# dec_output_file = opts.save_data + ".dec_embeddings.pt" +# if opts.src_embeddings is not None or opts.both_embeddings is not None: # logger.info("\nSaving encoder embeddings as:\n\t* enc: %s" % enc_output_file) # torch.save(convert_to_torch_tensor(src_vectors, enc_vocab), enc_output_file) -# # set the opt in place -# opt.pre_word_vecs_enc = enc_output_file -# if opt.tgt_embeddings is not None or opt.both_embeddings is not None: +# # set the opts in place +# opts.pre_word_vecs_enc = enc_output_file +# if opts.tgt_embeddings is not None or opts.both_embeddings is not None: # logger.info("\nSaving decoder embeddings as:\n\t* dec: %s" % dec_output_file) # torch.save(convert_to_torch_tensor(tgt_vectors, dec_vocab), dec_output_file) -# # set the opt in place -# opt.pre_word_vecs_dec = dec_output_file +# # set the opts in place +# opts.pre_word_vecs_dec = dec_output_file diff --git a/onmt/encoders/encoder.py b/mammoth/modules/encoder.py similarity index 90% rename from onmt/encoders/encoder.py rename to mammoth/modules/encoder.py index 71db9cad..9e55792b 100644 --- a/onmt/encoders/encoder.py +++ b/mammoth/modules/encoder.py @@ -2,13 +2,13 @@ import torch.nn as nn -from onmt.utils.misc import aeq +from mammoth.utils.misc import aeq class EncoderBase(nn.Module): """ Base encoder class. Specifies the interface used by different encoder types - and required by :class:`onmt.Models.NMTModel`. + and required by :class:`mammoth.Models.NMTModel`. .. mermaid:: @@ -31,7 +31,7 @@ class EncoderBase(nn.Module): """ @classmethod - def from_opt(cls, opt, embeddings=None): + def from_opts(cls, opts, embeddings=None): raise NotImplementedError def _check_args(self, src, lengths=None, hidden=None): diff --git a/onmt/decoders/layer_stack_decoder.py b/mammoth/modules/layer_stack_decoder.py similarity index 75% rename from onmt/decoders/layer_stack_decoder.py rename to mammoth/modules/layer_stack_decoder.py index 5fc7f594..a2136889 100644 --- a/onmt/decoders/layer_stack_decoder.py +++ b/mammoth/modules/layer_stack_decoder.py @@ -2,9 +2,9 @@ from torch import nn from typing import Dict, List -from onmt.decoders.decoder import DecoderBase -from onmt.models.adapters import Adapter, AdaptedTransformerDecoder -from onmt.utils.distributed import DatasetMetadata +from mammoth.modules.decoder import DecoderBase +from mammoth.models.adapters import Adapter, AdaptedTransformerDecoder +from mammoth.distributed import DatasetMetadata class LayerStackDecoder(DecoderBase): @@ -17,11 +17,11 @@ def __init__(self, embeddings, decoders): self._active: List[str] = [] @classmethod - def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False): + def from_opts(cls, opts, embeddings, task_queue_manager, is_on_top=False): """Alternate constructor for use during training.""" decoders = nn.ModuleList() - for layer_stack_index, n_layers in enumerate(opt.dec_layers): - is_on_top = layer_stack_index == len(opt.dec_layers) - 1 + for layer_stack_index, n_layers in enumerate(opts.dec_layers): + is_on_top = layer_stack_index == len(opts.dec_layers) - 1 stacks = nn.ModuleDict() for module_id in task_queue_manager.get_decoders(layer_stack_index): if module_id in stacks: @@ -29,26 +29,26 @@ def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False): continue stacks[module_id] = AdaptedTransformerDecoder( n_layers, - opt.dec_rnn_size, - opt.heads, - opt.transformer_ff, - opt.copy_attn, - opt.self_attn_type, - opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opts.model_dim, + opts.heads, + opts.transformer_ff, + opts.copy_attn, + opts.self_attn_type, + opts.dropout[0] if type(opts.dropout) is list else opts.dropout, ( - opt.attention_dropout[0] - if type(opt.attention_dropout) is list - else opt.attention_dropout + opts.attention_dropout[0] + if type(opts.attention_dropout) is list + else opts.attention_dropout ), None, # embeddings, - opt.max_relative_positions, - opt.aan_useffn, - opt.full_context_alignment, - opt.alignment_layer, - alignment_heads=opt.alignment_heads, - pos_ffn_activation_fn=opt.pos_ffn_activation_fn, + opts.max_relative_positions, + opts.aan_useffn, + opts.full_context_alignment, + opts.alignment_layer, + alignment_heads=opts.alignment_heads, + pos_ffn_activation_fn=opts.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top + nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top else nn.Identity() ), ) @@ -56,36 +56,36 @@ def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False): return cls(embeddings, decoders) @classmethod - def from_trans_opt(cls, model_opt, embeddings, opt_stack): + def from_trans_opt(cls, model_opts, embeddings, opt_stack): """Alternate constructor for use during translation.""" decoders = nn.ModuleList() - for layer_stack_index, n_layers in enumerate(model_opt.dec_layers): + for layer_stack_index, n_layers in enumerate(model_opts.dec_layers): stacks = nn.ModuleDict() - is_on_top = layer_stack_index == len(model_opt.dec_layers) - 1 + is_on_top = layer_stack_index == len(model_opts.dec_layers) - 1 module_opts = opt_stack['decoder'][layer_stack_index] module_id = module_opts['id'] stacks[module_id] = AdaptedTransformerDecoder( n_layers, - model_opt.dec_rnn_size, - model_opt.heads, - model_opt.transformer_ff, - model_opt.copy_attn, - model_opt.self_attn_type, - model_opt.dropout[0] if type(model_opt.dropout) is list else model_opt.dropout, + model_opts.model_dim, + model_opts.heads, + model_opts.transformer_ff, + model_opts.copy_attn, + model_opts.self_attn_type, + model_opts.dropout[0] if type(model_opts.dropout) is list else model_opts.dropout, ( - model_opt.attention_dropout[0] - if type(model_opt.attention_dropout) is list - else model_opt.attention_dropout + model_opts.attention_dropout[0] + if type(model_opts.attention_dropout) is list + else model_opts.attention_dropout ), None, # embeddings, - model_opt.max_relative_positions, - model_opt.aan_useffn, - model_opt.full_context_alignment, - model_opt.alignment_layer, - alignment_heads=model_opt.alignment_heads, - pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn, + model_opts.max_relative_positions, + model_opts.aan_useffn, + model_opts.full_context_alignment, + model_opts.alignment_layer, + alignment_heads=model_opts.alignment_heads, + pos_ffn_activation_fn=model_opts.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(model_opt.dec_rnn_size, eps=1e-6) if is_on_top + nn.LayerNorm(model_opts.model_dim, eps=1e-6) if is_on_top else nn.Identity() ), ) diff --git a/onmt/encoders/layer_stack_encoder.py b/mammoth/modules/layer_stack_encoder.py similarity index 76% rename from onmt/encoders/layer_stack_encoder.py rename to mammoth/modules/layer_stack_encoder.py index 77073fd9..a8de6dd4 100644 --- a/onmt/encoders/layer_stack_encoder.py +++ b/mammoth/modules/layer_stack_encoder.py @@ -1,10 +1,10 @@ from torch import nn from typing import Dict, List -from onmt.encoders.encoder import EncoderBase -from onmt.models.adapters import Adapter, AdaptedTransformerEncoder -from onmt.utils.misc import sequence_mask -from onmt.utils.distributed import DatasetMetadata +from mammoth.modules.encoder import EncoderBase +from mammoth.models.adapters import Adapter, AdaptedTransformerEncoder +from mammoth.utils.misc import sequence_mask +from mammoth.distributed import DatasetMetadata class LayerStackEncoder(EncoderBase): @@ -17,32 +17,32 @@ def __init__(self, embeddings, encoders): self._active: List[str] = [] @classmethod - def from_opt(cls, opt, embeddings, task_queue_manager): + def from_opts(cls, opts, embeddings, task_queue_manager): """Alternate constructor for use during training.""" encoders = nn.ModuleList() - for layer_stack_index, n_layers in enumerate(opt.enc_layers): + for layer_stack_index, n_layers in enumerate(opts.enc_layers): stacks = nn.ModuleDict() - is_on_top = layer_stack_index == len(opt.enc_layers) - 1 + is_on_top = layer_stack_index == len(opts.enc_layers) - 1 for module_id in task_queue_manager.get_encoders(layer_stack_index): if module_id in stacks: # several tasks using the same layer stack continue stacks[module_id] = AdaptedTransformerEncoder( n_layers, - opt.enc_rnn_size, - opt.heads, - opt.transformer_ff, - opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opts.model_dim, + opts.heads, + opts.transformer_ff, + opts.dropout[0] if type(opts.dropout) is list else opts.dropout, ( - opt.attention_dropout[0] - if type(opt.attention_dropout) is list - else opt.attention_dropout + opts.attention_dropout[0] + if type(opts.attention_dropout) is list + else opts.attention_dropout ), None, # embeddings, - opt.max_relative_positions, - pos_ffn_activation_fn=opt.pos_ffn_activation_fn, + opts.max_relative_positions, + pos_ffn_activation_fn=opts.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top + nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top else nn.Identity() ) ) @@ -50,30 +50,30 @@ def from_opt(cls, opt, embeddings, task_queue_manager): return cls(embeddings, encoders) @classmethod - def from_trans_opt(cls, model_opt, embeddings, opt_stack): + def from_trans_opt(cls, model_opts, embeddings, opt_stack): """Alternate constructor for use during translation.""" encoders = nn.ModuleList() - for layer_stack_index, n_layers in enumerate(model_opt.enc_layers): + for layer_stack_index, n_layers in enumerate(model_opts.enc_layers): stacks = nn.ModuleDict() module_opts = opt_stack['encoder'][layer_stack_index] module_id = module_opts['id'] - is_on_top = layer_stack_index == len(model_opt.enc_layers) - 1 + is_on_top = layer_stack_index == len(model_opts.enc_layers) - 1 stacks[module_id] = AdaptedTransformerEncoder( n_layers, - model_opt.enc_rnn_size, - model_opt.heads, - model_opt.transformer_ff, - model_opt.dropout[0] if type(model_opt.dropout) is list else model_opt.dropout, + model_opts.model_dim, + model_opts.heads, + model_opts.transformer_ff, + model_opts.dropout[0] if type(model_opts.dropout) is list else model_opts.dropout, ( - model_opt.attention_dropout[0] - if type(model_opt.attention_dropout) is list - else model_opt.attention_dropout + model_opts.attention_dropout[0] + if type(model_opts.attention_dropout) is list + else model_opts.attention_dropout ), None, # embeddings, - model_opt.max_relative_positions, - pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn, + model_opts.max_relative_positions, + pos_ffn_activation_fn=model_opts.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(model_opt.enc_rnn_size, eps=1e-6) if is_on_top + nn.LayerNorm(model_opts.model_dim, eps=1e-6) if is_on_top else nn.Identity() ) ) diff --git a/onmt/encoders/mean_encoder.py b/mammoth/modules/mean_encoder.py similarity index 81% rename from onmt/encoders/mean_encoder.py rename to mammoth/modules/mean_encoder.py index ca903c99..943a099e 100644 --- a/onmt/encoders/mean_encoder.py +++ b/mammoth/modules/mean_encoder.py @@ -1,6 +1,6 @@ """Define a minimal encoder.""" -from onmt.encoders.encoder import EncoderBase -from onmt.utils.misc import sequence_mask +from mammoth.modules.encoder import EncoderBase +from mammoth.utils.misc import sequence_mask import torch @@ -9,7 +9,7 @@ class MeanEncoder(EncoderBase): Args: num_layers (int): number of replicated layers - embeddings (onmt.modules.Embeddings): embedding module to use + embeddings (mammoth.modules.Embeddings): embedding module to use """ def __init__(self, num_layers, embeddings): @@ -18,9 +18,9 @@ def __init__(self, num_layers, embeddings): self.embeddings = embeddings @classmethod - def from_opt(cls, opt, embeddings): + def from_opts(cls, opts, embeddings): """Alternate constructor.""" - return cls(opt.enc_layers, embeddings) + return cls(opts.enc_layers, embeddings) def forward(self, src, lengths=None): """See :func:`EncoderBase.forward()`""" diff --git a/onmt/modules/multi_headed_attn.py b/mammoth/modules/multi_headed_attn.py similarity index 98% rename from onmt/modules/multi_headed_attn.py rename to mammoth/modules/multi_headed_attn.py index d44912b1..1a4b3028 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/mammoth/modules/multi_headed_attn.py @@ -3,9 +3,9 @@ import torch import torch.nn as nn -from onmt.utils.misc import generate_relative_positions_matrix, relative_matmul +from mammoth.utils.misc import generate_relative_positions_matrix, relative_matmul -# from onmt.utils.misc import aeq +# from mammoth.utils.misc import aeq class MultiHeadedAttention(nn.Module): diff --git a/onmt/modules/position_ffn.py b/mammoth/modules/position_ffn.py similarity index 100% rename from onmt/modules/position_ffn.py rename to mammoth/modules/position_ffn.py diff --git a/onmt/decoders/transformer.py b/mammoth/modules/transformer_decoder.py similarity index 94% rename from onmt/decoders/transformer.py rename to mammoth/modules/transformer_decoder.py index 637db455..a52d7a0a 100644 --- a/onmt/decoders/transformer.py +++ b/mammoth/modules/transformer_decoder.py @@ -6,11 +6,11 @@ import torch import torch.nn as nn -from onmt.decoders.decoder import DecoderBase -from onmt.modules import MultiHeadedAttention, AverageAttention -from onmt.modules.position_ffn import PositionwiseFeedForward -from onmt.modules.position_ffn import ActivationFunction -from onmt.utils.misc import sequence_mask +from mammoth.modules.decoder import DecoderBase +from mammoth.modules import MultiHeadedAttention, AverageAttention +from mammoth.modules.position_ffn import PositionwiseFeedForward +from mammoth.modules.position_ffn import ActivationFunction +from mammoth.utils.misc import sequence_mask class TransformerDecoderLayerBase(nn.Module): @@ -283,26 +283,26 @@ def __init__(self, d_model, copy_attn, embeddings, alignment_layer, layer_norm_m self.alignment_layer = alignment_layer @classmethod - def from_opt(cls, opt, embeddings, is_on_top=False): + def from_opts(cls, opts, embeddings, is_on_top=False): """Alternate constructor.""" return cls( - opt.dec_layers, - opt.dec_rnn_size, - opt.heads, - opt.transformer_ff, - opt.copy_attn, - opt.self_attn_type, - opt.dropout[0] if type(opt.dropout) is list else opt.dropout, - opt.attention_dropout[0] if type(opt.attention_dropout) is list else opt.attention_dropout, + opts.dec_layers, + opts.model_dim, + opts.heads, + opts.transformer_ff, + opts.copy_attn, + opts.self_attn_type, + opts.dropout[0] if type(opts.dropout) is list else opts.dropout, + opts.attention_dropout[0] if type(opts.attention_dropout) is list else opts.attention_dropout, embeddings, - opt.max_relative_positions, - opt.aan_useffn, - opt.full_context_alignment, - opt.alignment_layer, - alignment_heads=opt.alignment_heads, - pos_ffn_activation_fn=opt.pos_ffn_activation_fn, + opts.max_relative_positions, + opts.aan_useffn, + opts.full_context_alignment, + opts.alignment_layer, + alignment_heads=opts.alignment_heads, + pos_ffn_activation_fn=opts.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top + nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top else nn.Identity() ), ) @@ -366,7 +366,7 @@ class TransformerDecoder(TransformerDecoderBase): self_attn_type (str): type of self-attention scaled-dot, average dropout (float): dropout in residual, self-attn(dot) and feed-forward attention_dropout (float): dropout in context_attn (and self-attn(avg)) - embeddings (onmt.modules.Embeddings): + embeddings (mammoth.modules.Embeddings): embeddings to use, should have positional encodings max_relative_positions (int): Max distance between inputs in relative positions representations diff --git a/onmt/encoders/transformer.py b/mammoth/modules/transformer_encoder.py similarity index 85% rename from onmt/encoders/transformer.py rename to mammoth/modules/transformer_encoder.py index c020aa7d..43fd92ca 100644 --- a/onmt/encoders/transformer.py +++ b/mammoth/modules/transformer_encoder.py @@ -4,11 +4,11 @@ import torch.nn as nn -from onmt.encoders.encoder import EncoderBase -from onmt.modules import MultiHeadedAttention -from onmt.modules.position_ffn import PositionwiseFeedForward -from onmt.modules.position_ffn import ActivationFunction -from onmt.utils.misc import sequence_mask +from mammoth.modules.encoder import EncoderBase +from mammoth.modules import MultiHeadedAttention +from mammoth.modules.position_ffn import PositionwiseFeedForward +from mammoth.modules.position_ffn import ActivationFunction +from mammoth.utils.misc import sequence_mask class TransformerEncoderLayer(nn.Module): @@ -88,7 +88,7 @@ class TransformerEncoder(EncoderBase): heads (int): number of heads d_ff (int): size of the inner FF layer dropout (float): dropout parameters - embeddings (onmt.modules.Embeddings): + embeddings (mammoth.modules.Embeddings): embeddings to use, should have positional encodings pos_ffn_activation_fn (ActivationFunction): activation function choice for PositionwiseFeedForward layer @@ -133,20 +133,20 @@ def __init__( self.layer_norm = layer_norm_module @classmethod - def from_opt(cls, opt, embeddings, is_on_top=False): + def from_opts(cls, opts, embeddings, is_on_top=False): """Alternate constructor.""" return cls( - opt.enc_layers, - opt.enc_rnn_size, - opt.heads, - opt.transformer_ff, - opt.dropout[0] if type(opt.dropout) is list else opt.dropout, - opt.attention_dropout[0] if type(opt.attention_dropout) is list else opt.attention_dropout, + opts.enc_layers, + opts.model_dim, + opts.heads, + opts.transformer_ff, + opts.dropout[0] if type(opts.dropout) is list else opts.dropout, + opts.attention_dropout[0] if type(opts.attention_dropout) is list else opts.attention_dropout, embeddings, - opt.max_relative_positions, - pos_ffn_activation_fn=opt.pos_ffn_activation_fn, + opts.max_relative_positions, + pos_ffn_activation_fn=opts.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top + nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top else nn.Identity() ) ) diff --git a/onmt/modules/util_class.py b/mammoth/modules/util_class.py similarity index 100% rename from onmt/modules/util_class.py rename to mammoth/modules/util_class.py diff --git a/onmt/opts.py b/mammoth/opts.py similarity index 93% rename from onmt/opts.py rename to mammoth/opts.py index 695f8977..22a3dfd6 100644 --- a/onmt/opts.py +++ b/mammoth/opts.py @@ -1,12 +1,11 @@ """ Implementation of all available options """ import configargparse -from onmt.constants import ModelTask -from onmt.models.sru import CheckSRU -from onmt.modules.position_ffn import ACTIVATION_FUNCTIONS -from onmt.modules.position_ffn import ActivationFunction -from onmt.transforms import AVAILABLE_TRANSFORMS -from onmt.utils.distributed import TASK_DISTRIBUTION_STRATEGIES +from mammoth.constants import ModelTask +from mammoth.modules.position_ffn import ACTIVATION_FUNCTIONS +from mammoth.modules.position_ffn import ActivationFunction +from mammoth.transforms import AVAILABLE_TRANSFORMS +from mammoth.distributed import TASK_DISTRIBUTION_STRATEGIES def config_opts(parser): @@ -56,7 +55,7 @@ def _add_logging_opts(parser, is_train=True): "--tensorboard_log_dir", "-tensorboard_log_dir", type=str, - default="runs/onmt", + default="runs/mammoth", help="Log directory for Tensorboard. This is also the name of the run.", ) group.add( @@ -94,10 +93,10 @@ def _add_reproducibility_opts(parser): def _add_dynamic_corpus_opts(parser, build_vocab_only=False): """Options related to training corpus, type: a list of dictionary.""" - group = parser.add_argument_group('Data') + group = parser.add_argument_group('Data/Tasks') group.add( - "-data", - "--data", + "-tasks", + "--tasks", required=True, help="List of datasets and their specifications. See examples/*.yaml for further details.", ) @@ -274,7 +273,7 @@ def _add_dynamic_transform_opts(parser): """Options related to transforms. Options that specified in the definitions of each transform class - at `onmt/transforms/*.py`. + at `mammoth/transforms/*.py`. """ for name, transform_cls in AVAILABLE_TRANSFORMS.items(): transform_cls.add_options(parser) @@ -285,7 +284,7 @@ def dynamic_prepare_opts(parser, build_vocab_only=False): Add all dynamic data prepare related options to parser. If `build_vocab_only` set to True, then only contains options that - will be used in `onmt/bin/build_vocab.py`. + will be used in `mammoth/bin/build_vocab.py`. """ config_opts(parser) _add_dynamic_corpus_opts(parser, build_vocab_only=build_vocab_only) @@ -305,9 +304,6 @@ def model_opts(parser): # Embedding Options group = parser.add_argument_group('Model-Embeddings') - group.add('--src_word_vec_size', '-src_word_vec_size', type=int, default=500, help='Word embedding size for src.') - group.add('--tgt_word_vec_size', '-tgt_word_vec_size', type=int, default=500, help='Word embedding size for tgt.') - group.add('--word_vec_size', '-word_vec_size', type=int, default=-1, help='Word embedding size for src and tgt.') group.add( '--share_decoder_embeddings', @@ -386,43 +382,42 @@ def model_opts(parser): '--encoder_type', '-encoder_type', type=str, - default='rnn', - choices=['rnn', 'brnn', 'ggnn', 'mean', 'transformer', 'cnn', 'transformer_lm'], + default='transformer', + choices=['mean', 'transformer'], help="Type of encoder layer to use. Non-RNN layers " "are experimental. Options are " - "[rnn|brnn|ggnn|mean|transformer|cnn|transformer_lm].", + "[mean|transformer].", ) group.add( '--decoder_type', '-decoder_type', type=str, - default='rnn', - choices=['rnn', 'transformer', 'cnn', 'transformer_lm'], + default='transformer', + choices=['transformer'], help="Type of decoder layer to use. Non-RNN layers " "are experimental. Options are " - "[rnn|transformer|cnn|transformer].", + "[transformer].", ) group.add('--layers', '-layers', type=int, default=-1, help='Deprecated') group.add('--enc_layers', '-enc_layers', nargs='+', type=int, help='Number of layers in each encoder') group.add('--dec_layers', '-dec_layers', nargs='+', type=int, help='Number of layers in each decoder') group.add( - '--rnn_size', - '-rnn_size', + '--model_dim', + '-model_dim', type=int, default=-1, - help="Size of rnn hidden states. Overwrites enc_rnn_size and dec_rnn_size", - ) - group.add('--enc_rnn_size', '-enc_rnn_size', type=int, default=500, help="Size of encoder rnn hidden states.") - group.add('--dec_rnn_size', '-dec_rnn_size', type=int, default=500, help="Size of decoder rnn hidden states.") - group.add( - '--cnn_kernel_width', - '-cnn_kernel_width', - type=int, - default=3, - help="Size of windows in the cnn, the kernel_size is (cnn_kernel_width, 1) in conv layer", + help="Size of rnn hidden states.", ) + # group.add( + # '--cnn_kernel_width', + # '-cnn_kernel_width', + # type=int, + # default=3, + # help="Size of windows in the cnn, the kernel_size is (cnn_kernel_width, 1) in conv layer", + # ) + group.add( '--pos_ffn_activation_fn', '-pos_ffn_activation_fn', @@ -435,43 +430,32 @@ def model_opts(parser): f' {ActivationFunction.relu}.', ) - group.add( - '--input_feed', - '-input_feed', - type=int, - default=1, - help="Feed the context vector at each time step as " - "additional input (via concatenation with the word " - "embeddings) to the decoder.", - ) + # group.add( + # '--input_feed', + # '-input_feed', + # type=int, + # default=1, + # help="Feed the context vector at each time step as " + # "additional input (via concatenation with the word " + # "embeddings) to the decoder.", + # ) group.add( '--bridge', '-bridge', action="store_true", help="Have an additional layer between the last encoder state and the first decoder state", ) - group.add( - '--rnn_type', - '-rnn_type', - type=str, - default='LSTM', - choices=['LSTM', 'GRU', 'SRU'], - action=CheckSRU, - help="The gate type to use in the RNNs", - ) # group.add('--residual', '-residual', action="store_true", # help="Add residual connections between RNN layers.") - group.add('--brnn', '-brnn', action=DeprecateAction, help="Deprecated, use `encoder_type`.") - - group.add( - '--context_gate', - '-context_gate', - type=str, - default=None, - choices=['source', 'target', 'both'], - help="Type of context gate to use. Do not select for no context gate.", - ) + # group.add( + # '--context_gate', + # '-context_gate', + # type=str, + # default=None, + # choices=['source', 'target', 'both'], + # help="Type of context gate to use. Do not select for no context gate.", + # ) # The following options (bridge_extra_node to n_steps) are used # for training with --encoder_type ggnn (Gated Graph Neural Network). @@ -514,7 +498,7 @@ def model_opts(parser): '-global_attention_function', type=str, default="softmax", - choices=["softmax", "sparsemax"], + choices=["softmax"], ) group.add( '--self_attn_type', @@ -582,10 +566,10 @@ def model_opts(parser): '--generator_function', '-generator_function', default="softmax", - choices=["softmax", "sparsemax"], + choices=["softmax"], help="Which function to use for generating " "probabilities over the target vocabulary (choices: " - "softmax, sparsemax)", + "softmax)", ) group.add('--copy_attn_force', '-copy_attn_force', action="store_true", help='When available, train to copy.') group.add('--reuse_copy_attn', '-reuse_copy_attn', action="store_true", help="Reuse standard attention for copy") @@ -616,7 +600,7 @@ def model_opts(parser): type=str, default="O1", choices=["O0", "O1", "O2", "O3"], - help="For FP16 training, the opt_level to use. See https://nvidia.github.io/apex/amp.html#opt-levels.", + help="For FP16 training, the opt_level to use. See https://nvidia.github.io/apex/amp.html#opts-levels.", ) # attention bridge options @@ -857,7 +841,7 @@ def _add_train_general_opts(parser): '--optim', '-optim', default='sgd', - choices=['sgd', 'adagrad', 'adadelta', 'adam', 'adamw', 'sparseadam', 'adafactor', 'fusedadam'], + choices=['sgd', 'adagrad', 'adadelta', 'adam', 'adamw', 'adafactor', 'fusedadam'], help="Optimization method.", ) group.add( @@ -1009,8 +993,8 @@ def _add_train_general_opts(parser): def _add_train_dynamic_data(parser): group = parser.add_argument_group("Dynamic data") group.add( - "-bucket_size", - "--bucket_size", + "-pool_size", + "--pool_size", type=int, default=2048, help="Number of examples to dynamically pool before batching.", @@ -1220,7 +1204,7 @@ def translate_opts(parser, dynamic=False): help="Divide src and tgt (if applicable) into " "smaller multiple src and tgt files, then " "build shards, each shard will have " - "opt.shard_size samples except last shard. " + "opts.shard_size samples except last shard. " "shard_size=0 means no segmentation " "shard_size>0 means segment dataset into multiple shards, " "each shard has shard_size samples", diff --git a/onmt/rmsnorm_torch.py b/mammoth/rmsnorm_torch.py similarity index 100% rename from onmt/rmsnorm_torch.py rename to mammoth/rmsnorm_torch.py diff --git a/onmt/tests/__init__.py b/mammoth/tests/__init__.py similarity index 100% rename from onmt/tests/__init__.py rename to mammoth/tests/__init__.py diff --git a/onmt/tests/output_hyp.txt b/mammoth/tests/output_hyp.txt similarity index 100% rename from onmt/tests/output_hyp.txt rename to mammoth/tests/output_hyp.txt diff --git a/onmt/tests/pull_request_chk.sh b/mammoth/tests/pull_request_chk.sh similarity index 100% rename from onmt/tests/pull_request_chk.sh rename to mammoth/tests/pull_request_chk.sh diff --git a/onmt/tests/rebuild_test_models.sh b/mammoth/tests/rebuild_test_models.sh similarity index 100% rename from onmt/tests/rebuild_test_models.sh rename to mammoth/tests/rebuild_test_models.sh diff --git a/onmt/tests/sample_glove.txt b/mammoth/tests/sample_glove.txt similarity index 100% rename from onmt/tests/sample_glove.txt rename to mammoth/tests/sample_glove.txt diff --git a/onmt/tests/test_beam_search.py b/mammoth/tests/test_beam_search.py similarity index 99% rename from onmt/tests/test_beam_search.py rename to mammoth/tests/test_beam_search.py index 8a6e54a4..f43dd134 100644 --- a/onmt/tests/test_beam_search.py +++ b/mammoth/tests/test_beam_search.py @@ -1,6 +1,6 @@ import unittest -from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer -from onmt.translate.beam_search import BeamSearchLM +from mammoth.translate.beam_search import BeamSearch, GNMTGlobalScorer +from mammoth.translate.beam_search import BeamSearchLM from copy import deepcopy diff --git a/onmt/tests/test_data_prepare.py b/mammoth/tests/test_data_prepare.py similarity index 79% rename from onmt/tests/test_data_prepare.py rename to mammoth/tests/test_data_prepare.py index 9c7fdffb..d1780982 100644 --- a/onmt/tests/test_data_prepare.py +++ b/mammoth/tests/test_data_prepare.py @@ -8,10 +8,10 @@ # import glob # import os # -# from onmt.utils.parse import ArgumentParser -# from onmt.opts import dynamic_prepare_opts -# from onmt.bin.train import prepare_fields_transforms -# from onmt.constants import CorpusName +# from mammoth.utils.parse import ArgumentParser +# from mammoth.opts import dynamic_prepare_opts +# from mammoth.bin.train import prepare_fields_transforms +# from mammoth.constants import CorpusName # # # SAVE_DATA_PREFIX = 'data/test_data_prepare' @@ -27,11 +27,11 @@ # '-tgt_vocab', 'data/vocab-train.tgt' # ] # -# opt = parser.parse_known_args(default_opts)[0] +# opts = parser.parse_known_args(default_opts)[0] # # Inject some dummy training options that may needed when build fields -# opt.copy_attn = False -# ArgumentParser.validate_prepare_opts(opt) -# return opt +# opts.copy_attn = False +# ArgumentParser.validate_prepare_opts(opts) +# return opts # # # default_opts = get_default_opts() @@ -40,15 +40,15 @@ # class TestData(unittest.TestCase): # def __init__(self, *args, **kwargs): # super(TestData, self).__init__(*args, **kwargs) -# self.opt = default_opts +# self.opts = default_opts # -# def dataset_build(self, opt): +# def dataset_build(self, opts): # try: -# prepare_fields_transforms(opt) +# prepare_fields_transforms(opts) # except SystemExit as err: # print(err) # except IOError as err: -# if opt.skip_empty_level != 'error': +# if opts.skip_empty_level != 'error': # raise err # else: # print(f"Catched IOError: {err}") @@ -56,10 +56,10 @@ # # Remove the generated *pt files. # for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'): # os.remove(pt) -# if self.opt.save_data: +# if self.opts.save_data: # # Remove the generated data samples # sample_path = os.path.join( -# os.path.dirname(self.opt.save_data), +# os.path.dirname(self.opts.save_data), # CorpusName.SAMPLE) # if os.path.exists(sample_path): # for f in glob.glob(sample_path + '/*'): @@ -78,12 +78,12 @@ # # def test_method(self): # if param_setting: -# opt = copy.deepcopy(self.opt) +# opts = copy.deepcopy(self.opts) # for param, setting in param_setting: -# setattr(opt, param, setting) +# setattr(opts, param, setting) # else: -# opt = self.opt -# getattr(self, methodname)(opt) +# opts = self.opts +# getattr(self, methodname)(opts) # if param_setting: # name = 'test_' + methodname + "_" + "_".join( # str(param_setting).split()) diff --git a/onmt/tests/test_embeddings.py b/mammoth/tests/test_embeddings.py similarity index 98% rename from onmt/tests/test_embeddings.py rename to mammoth/tests/test_embeddings.py index a152838d..9abac740 100644 --- a/onmt/tests/test_embeddings.py +++ b/mammoth/tests/test_embeddings.py @@ -1,12 +1,12 @@ import unittest -from onmt.modules.embeddings import Embeddings +from mammoth.modules.embeddings import Embeddings import itertools from copy import deepcopy import torch -from onmt.tests.utils_for_tests import product_dict +from mammoth.tests.utils_for_tests import product_dict class TestEmbeddings(unittest.TestCase): diff --git a/onmt/tests/test_greedy_search.py b/mammoth/tests/test_greedy_search.py similarity index 99% rename from onmt/tests/test_greedy_search.py rename to mammoth/tests/test_greedy_search.py index b32e016f..6c718e69 100644 --- a/onmt/tests/test_greedy_search.py +++ b/mammoth/tests/test_greedy_search.py @@ -1,5 +1,5 @@ import unittest -from onmt.translate.greedy_search import GreedySearch +from mammoth.translate.greedy_search import GreedySearch import torch diff --git a/onmt/tests/test_model.pt b/mammoth/tests/test_model.pt similarity index 100% rename from onmt/tests/test_model.pt rename to mammoth/tests/test_model.pt diff --git a/onmt/tests/test_model2.pt b/mammoth/tests/test_model2.pt similarity index 100% rename from onmt/tests/test_model2.pt rename to mammoth/tests/test_model2.pt diff --git a/onmt/tests/test_model_lm.pt b/mammoth/tests/test_model_lm.pt similarity index 100% rename from onmt/tests/test_model_lm.pt rename to mammoth/tests/test_model_lm.pt diff --git a/onmt/tests/test_models.py b/mammoth/tests/test_models.py similarity index 68% rename from onmt/tests/test_models.py rename to mammoth/tests/test_models.py index 1d70a3cc..089ca796 100644 --- a/onmt/tests/test_models.py +++ b/mammoth/tests/test_models.py @@ -3,24 +3,24 @@ import torch -import onmt -import onmt.opts -from onmt.model_builder import build_embeddings, build_encoder, build_decoder -from onmt.inputters.vocab import Vocab, DEFAULT_SPECIALS -from onmt.utils.parse import ArgumentParser +import mammoth +import mammoth.opts +from mammoth.model_builder import build_embeddings, build_encoder, build_decoder +from mammoth.inputters.vocab import Vocab, DEFAULT_SPECIALS +from mammoth.utils.parse import ArgumentParser parser = ArgumentParser(description='train.py') -onmt.opts.model_opts(parser) -onmt.opts._add_train_general_opts(parser) +mammoth.opts.model_opts(parser) +mammoth.opts._add_train_general_opts(parser) # -data option is required, but not used in this test, so dummy. -opt = parser.parse_known_args(['-data', 'dummy', '-node_rank', '0'])[0] +opts = parser.parse_known_args(['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500'])[0] class TestModel(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestModel, self).__init__(*args, **kwargs) - self.opt = opt + self.opts = opts def get_field(self): return Vocab(None, items=[], tag='dummy', specials=list(DEFAULT_SPECIALS)) @@ -32,82 +32,77 @@ def get_batch(self, source_l=3, bsize=1): test_length = torch.ones(bsize).fill_(source_l).long() return test_src, test_tgt, test_length - def embeddings_forward(self, opt, source_l=3, bsize=1): + def embeddings_forward(self, opts, source_l=3, bsize=1): ''' Tests if the embeddings works as expected args: - opt: set of options + opts: set of options source_l: Length of generated input sentence bsize: Batchsize of generated input ''' word_field = self.get_field() - emb = build_embeddings(opt, word_field) + emb = build_embeddings(opts, word_field) test_src, _, __ = self.get_batch(source_l=source_l, bsize=bsize) - if opt.decoder_type == 'transformer': + if opts.decoder_type == 'transformer': input = torch.cat([test_src, test_src], 0) res = emb(input) - compare_to = torch.zeros(source_l * 2, bsize, opt.src_word_vec_size) + compare_to = torch.zeros(source_l * 2, bsize, opts.model_dim) else: res = emb(test_src) - compare_to = torch.zeros(source_l, bsize, opt.src_word_vec_size) + compare_to = torch.zeros(source_l, bsize, opts.model_dim) self.assertEqual(res.size(), compare_to.size()) - def encoder_forward(self, opt, source_l=3, bsize=1): + def encoder_forward(self, opts, source_l=3, bsize=1): ''' Tests if the encoder works as expected args: - opt: set of options + opts: set of options source_l: Length of generated input sentence bsize: Batchsize of generated input ''' - if opt.rnn_size > 0: - opt.enc_rnn_size = opt.rnn_size word_field = self.get_field() - embeddings = build_embeddings(opt, word_field) - enc = build_encoder(opt, embeddings) + embeddings = build_embeddings(opts, word_field) + enc = build_encoder(opts, embeddings) test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize) hidden_t, outputs, test_length = enc(test_src, test_length) # Initialize vectors to compare size with - test_hid = torch.zeros(self.opt.enc_layers, bsize, opt.enc_rnn_size) - test_out = torch.zeros(source_l, bsize, opt.dec_rnn_size) + test_hid = torch.zeros(self.opts.enc_layers, bsize, opts.model_dim) + test_out = torch.zeros(source_l, bsize, opts.model_dim) # Ensure correct sizes and types self.assertEqual(test_hid.size(), hidden_t[0].size(), hidden_t[1].size()) self.assertEqual(test_out.size(), outputs.size()) self.assertEqual(type(outputs), torch.Tensor) - def nmtmodel_forward(self, opt, source_l=3, bsize=1): + def nmtmodel_forward(self, opts, source_l=3, bsize=1): """ - Creates a nmtmodel with a custom opt function. + Creates a nmtmodel with a custom opts function. Forwards a testbatch and checks output size. Args: - opt: Namespace with options + opts: Namespace with options source_l: length of input sequence bsize: batchsize """ - if opt.rnn_size > 0: - opt.enc_rnn_size = opt.rnn_size - opt.dec_rnn_size = opt.rnn_size word_field = self.get_field() - embeddings = build_embeddings(opt, word_field) - enc = build_encoder(opt, embeddings) + embeddings = build_embeddings(opts, word_field) + enc = build_encoder(opts, embeddings) - embeddings = build_embeddings(opt, word_field, for_encoder=False) - dec = build_decoder(opt, embeddings) + embeddings = build_embeddings(opts, word_field, for_encoder=False) + dec = build_decoder(opts, embeddings) - model = onmt.models.model.NMTModel(enc, dec) + model = mammoth.models.model.NMTModel(enc, dec) test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize) outputs, attn = model(test_src, test_tgt, test_length) - outputsize = torch.zeros(source_l - 1, bsize, opt.dec_rnn_size) + outputsize = torch.zeros(source_l - 1, bsize, opts.model_dim) # Make sure that output has the correct size and type self.assertEqual(outputs.size(), outputsize.size()) self.assertEqual(type(outputs), torch.Tensor) @@ -123,12 +118,12 @@ def _add_test(param_setting, methodname): """ def test_method(self): - opt = copy.deepcopy(self.opt) + opts = copy.deepcopy(self.opts) if param_setting: for param, setting in param_setting: - setattr(opt, param, setting) - ArgumentParser.update_model_opts(opt) - getattr(self, methodname)(opt) + setattr(opts, param, setting) + ArgumentParser.update_model_opts(opts) + getattr(self, methodname)(opts) if param_setting: name = 'test_' + methodname + "_" + "_".join(str(param_setting).split()) @@ -141,7 +136,7 @@ def test_method(self): ''' TEST PARAMETERS ''' -opt.brnn = False +opts.brnn = False # FIXME: Most tests disabled: FoTraNMT only supports Transformer test_embeddings = [ @@ -156,7 +151,7 @@ def test_method(self): tests_encoder = [ # [], # [('encoder_type', 'mean')], - # [('encoder_type', 'transformer'), ('word_vec_size', 16), ('rnn_size', 16)], + # [('encoder_type', 'transformer'), ('word_vec_size', 16), ('model_dim', 16)], # [], ] @@ -173,14 +168,14 @@ def test_method(self): ('encoder_type', 'transformer'), ('src_word_vec_size', 16), ('tgt_word_vec_size', 16), - ('rnn_size', 16), + ('model_dim', 16), ], [ ('decoder_type', 'transformer'), ('encoder_type', 'transformer'), ('src_word_vec_size', 16), ('tgt_word_vec_size', 16), - ('rnn_size', 16), + ('model_dim', 16), ('position_encoding', True), ], # [('coverage_attn', True)], @@ -198,10 +193,6 @@ def test_method(self): # [], ] -if onmt.models.sru.check_sru_requirement(): - # """ Only do SRU test if requirment is safisfied. """ - # SRU doesn't support input_feed. - tests_nmtmodel.append([('rnn_type', 'SRU'), ('input_feed', 0)]) # ## FIXME: Broken in FoTraNMT # for p in tests_nmtmodel: diff --git a/onmt/tests/test_models.sh b/mammoth/tests/test_models.sh similarity index 100% rename from onmt/tests/test_models.sh rename to mammoth/tests/test_models.sh diff --git a/onmt/tests/test_simple.py b/mammoth/tests/test_simple.py similarity index 50% rename from onmt/tests/test_simple.py rename to mammoth/tests/test_simple.py index bd607e57..abdafbda 100644 --- a/onmt/tests/test_simple.py +++ b/mammoth/tests/test_simple.py @@ -1,6 +1,6 @@ -import onmt +import mammoth def test_load(): - onmt + mammoth pass diff --git a/onmt/tests/test_subword_marker.py b/mammoth/tests/test_subword_marker.py similarity index 98% rename from onmt/tests/test_subword_marker.py rename to mammoth/tests/test_subword_marker.py index afa17fcf..63dd57c9 100644 --- a/onmt/tests/test_subword_marker.py +++ b/mammoth/tests/test_subword_marker.py @@ -1,8 +1,8 @@ import unittest -from onmt.transforms.denoising import word_start_finder -from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer -from onmt.constants import SubwordMarker +from mammoth.transforms.denoising import word_start_finder +from mammoth.utils.alignment import subword_map_by_joiner, subword_map_by_spacer +from mammoth.constants import SubwordMarker class TestWordStartFinder(unittest.TestCase): diff --git a/onmt/tests/test_task_distribution_strategy.py b/mammoth/tests/test_task_distribution_strategy.py similarity index 86% rename from onmt/tests/test_task_distribution_strategy.py rename to mammoth/tests/test_task_distribution_strategy.py index aba90222..612d9140 100644 --- a/onmt/tests/test_task_distribution_strategy.py +++ b/mammoth/tests/test_task_distribution_strategy.py @@ -1,11 +1,11 @@ import pytest from argparse import Namespace -from onmt.utils.distributed import WeightedSamplingTaskDistributionStrategy, RoundRobinTaskDistributionStrategy +from mammoth.distributed.tasks import WeightedSamplingTaskDistributionStrategy, RoundRobinTaskDistributionStrategy def test_weights_all_zero(): - opt = Namespace(data={ + opts = Namespace(data={ 'a': { 'weight': 0, 'introduce_at_training_step': 0, @@ -20,12 +20,12 @@ def test_weights_all_zero(): }, }) with pytest.raises(ValueError) as exc_info: - WeightedSamplingTaskDistributionStrategy.from_opt(['a', 'b'], opt) + WeightedSamplingTaskDistributionStrategy.from_opts(['a', 'b'], opts) assert 'Can not set "weight" of all corpora on a device to zero' in str(exc_info.value) def test_weights_all_postponed(): - opt = Namespace(data={ + opts = Namespace(data={ 'a': { 'weight': 1, 'introduce_at_training_step': 1, @@ -40,12 +40,12 @@ def test_weights_all_postponed(): }, }) with pytest.raises(ValueError) as exc_info: - WeightedSamplingTaskDistributionStrategy.from_opt(['a', 'b'], opt) + WeightedSamplingTaskDistributionStrategy.from_opts(['a', 'b'], opts) assert 'Can not set "introduce_at_training_step" of all corpora on a device to nonzero' in str(exc_info.value) def test_invalid_curriculum(): - opt = Namespace(data={ + opts = Namespace(data={ # 'a' disabled by weight 'a': { 'weight': 0, @@ -62,12 +62,12 @@ def test_invalid_curriculum(): }, }) with pytest.raises(ValueError) as exc_info: - WeightedSamplingTaskDistributionStrategy.from_opt(['a', 'b'], opt) + WeightedSamplingTaskDistributionStrategy.from_opts(['a', 'b'], opts) assert 'Invalid curriculum' in str(exc_info.value) def test_sampling_task_distribution_strategy(): - opt = Namespace(data={ + opts = Namespace(data={ # 'a' disabled by weight 'a': { 'weight': 0, @@ -89,7 +89,7 @@ def test_sampling_task_distribution_strategy(): 'introduce_at_training_step': 0, }, }) - strategy = WeightedSamplingTaskDistributionStrategy.from_opt(['a', 'b', 'c'], opt) + strategy = WeightedSamplingTaskDistributionStrategy.from_opts(['a', 'b', 'c'], opts) all_samples = [] n_samples = 10 n_batches = 1000 diff --git a/onmt/tests/test_task_queue_manager.py b/mammoth/tests/test_task_queue_manager.py similarity index 88% rename from onmt/tests/test_task_queue_manager.py rename to mammoth/tests/test_task_queue_manager.py index efc75edd..17155080 100644 --- a/onmt/tests/test_task_queue_manager.py +++ b/mammoth/tests/test_task_queue_manager.py @@ -3,7 +3,7 @@ from collections import OrderedDict from unittest.mock import MagicMock -from onmt.utils.distributed import TaskQueueManager, WorldContext +from mammoth.distributed import TaskQueueManager, WorldContext def test_init_minimal(): @@ -20,9 +20,9 @@ def test_init_minimal(): 'train_c-d': {'path_src': 'dummy', 'path_tgt': 'dummy', 'src_tgt': 'c-d'}, } } - opt = Namespace(**opt_dict) - world_context = WorldContext.from_opt(opt) - task_queue_manager = TaskQueueManager.from_opt(opt, world_context) + opts = Namespace(**opt_dict) + world_context = WorldContext.from_opts(opts) + task_queue_manager = TaskQueueManager.from_opts(opts, world_context) assert world_context.is_gpu() assert world_context.is_distributed() assert len(task_queue_manager.tasks) == 2 @@ -95,15 +95,15 @@ def create_basic_task_queue_manager(): }, } } - opt = Namespace(**opt_dict) - world_context = WorldContext.from_opt(opt) - task_queue_manager = TaskQueueManager.from_opt(opt, world_context) - return task_queue_manager, opt + opts = Namespace(**opt_dict) + world_context = WorldContext.from_opts(opts) + task_queue_manager = TaskQueueManager.from_opts(opts, world_context) + return task_queue_manager, opts def test_init_basic(): - global_task_queue_manager, opt = create_basic_task_queue_manager() - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opt=opt) + global_task_queue_manager, opts = create_basic_task_queue_manager() + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opts=opts) world_context = task_queue_manager.world_context assert world_context.is_gpu() assert world_context.is_distributed() @@ -129,7 +129,7 @@ def __call__(self, sorted_global_ranks): self.group_idx += 1 return result - global_task_queue_manager, opt = create_basic_task_queue_manager() + global_task_queue_manager, opts = create_basic_task_queue_manager() all_groups = global_task_queue_manager.create_all_distributed_groups(new_group_func=MockGroup()) assert all_groups == { 'src_emb': OrderedDict({ @@ -157,8 +157,8 @@ def __call__(self, sorted_global_ranks): self.group_idx += 1 return result - global_task_queue_manager, opt = create_basic_task_queue_manager() - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opt=opt) + global_task_queue_manager, opts = create_basic_task_queue_manager() + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opts=opts) my_groups = task_queue_manager.get_distributed_groups(new_group_func=MockGroup()) assert my_groups == { 'encoder': OrderedDict({ @@ -196,10 +196,10 @@ def test_cpu_distributed_groups(): }, } } - opt = Namespace(**opt_dict) - world_context = WorldContext.from_opt(opt) - global_task_queue_manager = TaskQueueManager.from_opt(opt, world_context) - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opt=opt) + opts = Namespace(**opt_dict) + world_context = WorldContext.from_opts(opts) + global_task_queue_manager = TaskQueueManager.from_opts(opts, world_context) + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opts=opts) new_group_func = MagicMock().new_group_func my_groups = task_queue_manager.get_distributed_groups(new_group_func=new_group_func) # No groups should be created when running on CPU @@ -248,10 +248,10 @@ def test_distributed_groups_no_encoder_group(): }, } } - opt = Namespace(**opt_dict) - world_context = WorldContext.from_opt(opt) - global_task_queue_manager = TaskQueueManager.from_opt(opt, world_context) - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opt=opt) + opts = Namespace(**opt_dict) + world_context = WorldContext.from_opts(opts) + global_task_queue_manager = TaskQueueManager.from_opts(opts, world_context) + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opts=opts) new_group_func = MagicMock().new_group_func my_groups = task_queue_manager.get_distributed_groups(new_group_func=new_group_func) # No groups should be created: @@ -269,20 +269,20 @@ def test_distributed_groups_no_encoder_group(): # (side, lang): f'{side} {lang}' for (side, lang) in # [('src', 'a'), ('src', 'c'), ('src', 'e'), ('tgt', 'b'), ('tgt', 'd')] # } -# global_task_queue_manager, opt = create_basic_task_queue_manager() -# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opt=opt) +# global_task_queue_manager, opts = create_basic_task_queue_manager() +# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opts=opts) # fields = task_queue_manager.get_fields('src', mock_fields) # assert fields == [('src', 'a', None, 'src a')] # fields = task_queue_manager.get_fields('tgt', mock_fields) # assert fields == [('tgt', 'b', None, 'tgt b')] # -# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opt=opt) +# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opts=opts) # fields = task_queue_manager.get_fields('src', mock_fields) # assert fields == [('src', 'c', None, 'src c'), ('src', 'a', 'x', 'src a')] # fields = task_queue_manager.get_fields('tgt', mock_fields) # assert fields == [('tgt', 'd', None, 'tgt d')] # -# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=1, local_rank=0, opt=opt) +# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=1, local_rank=0, opts=opts) # fields = task_queue_manager.get_fields('src', mock_fields) # assert fields == [('src', 'e', None, 'src e')] # fields = task_queue_manager.get_fields('tgt', mock_fields) @@ -290,8 +290,8 @@ def test_distributed_groups_no_encoder_group(): def test_basic_getters(): - global_task_queue_manager, opt = create_basic_task_queue_manager() - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opt=opt) + global_task_queue_manager, opts = create_basic_task_queue_manager() + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opts=opts) encoders = list(task_queue_manager.get_encoders(0)) assert encoders == ['x'] decoders = list(task_queue_manager.get_decoders(0)) @@ -303,7 +303,7 @@ def test_basic_getters(): generators = list(task_queue_manager.get_generators()) assert generators == ['b'] - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opt=opt) + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opts=opts) encoders = list(task_queue_manager.get_encoders(0)) assert encoders == ['xx', 'x'] decoders = list(task_queue_manager.get_decoders(0)) diff --git a/onmt/tests/test_text_dataset.py b/mammoth/tests/test_text_dataset.py similarity index 99% rename from onmt/tests/test_text_dataset.py rename to mammoth/tests/test_text_dataset.py index 0fffe0ca..9d6f3b56 100644 --- a/onmt/tests/test_text_dataset.py +++ b/mammoth/tests/test_text_dataset.py @@ -7,7 +7,7 @@ # # # from torchtext.legacy.data import Field # -# from onmt.tests.utils_for_tests import product_dict +# from mammoth.tests.utils_for_tests import product_dict # # # class TestTextMultiField(unittest.TestCase): diff --git a/onmt/tests/test_transform.py b/mammoth/tests/test_transform.py similarity index 91% rename from onmt/tests/test_transform.py rename to mammoth/tests/test_transform.py index 32e25f83..e53ec3e5 100644 --- a/onmt/tests/test_transform.py +++ b/mammoth/tests/test_transform.py @@ -5,13 +5,13 @@ import yaml import math from argparse import Namespace -from onmt.transforms import ( +from mammoth.transforms import ( get_transforms_cls, get_specials, make_transforms, TransformPipe, ) -from onmt.transforms.denoising import BARTNoising +from mammoth.transforms.denoising import BARTNoising class TestTransform(unittest.TestCase): @@ -31,13 +31,13 @@ def test_transform_register(self): def test_vocab_required_transform(self): transforms_cls = get_transforms_cls(["denoising", "switchout"]) - opt = Namespace(seed=-1, switchout_temperature=1.0) + opts = Namespace(seed=-1, switchout_temperature=1.0) # transforms that require vocab will not create if not provide vocab - transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None) + transforms = make_transforms(opts, transforms_cls, vocabs=None, task=None) self.assertEqual(len(transforms), 0) with self.assertRaises(ValueError): - transforms_cls["switchout"](opt).warm_up(vocabs=None) - transforms_cls["denoising"](opt).warm_up(vocabs=None) + transforms_cls["switchout"](opts).warm_up(vocabs=None) + transforms_cls["denoising"](opts).warm_up(vocabs=None) def test_transform_specials(self): transforms_cls = get_transforms_cls(["prefix"]) @@ -52,8 +52,8 @@ def test_transform_specials(self): tgt_prefix: "⦅_pf_tgt⦆" """ ) - opt = Namespace(data=corpora) - specials = get_specials(opt, transforms_cls) + opts = Namespace(tasks=corpora) + specials = get_specials(opts, transforms_cls) specials_expected = {"src": {"⦅_pf_src⦆"}, "tgt": {"⦅_pf_tgt⦆"}} self.assertEqual(specials, specials_expected) @@ -71,13 +71,13 @@ def test_transform_pipe(self): tgt_prefix: "⦅_pf_tgt⦆" """ ) - opt = Namespace(data=corpora, seed=-1) - prefix_transform = prefix_cls(opt) + opts = Namespace(tasks=corpora, seed=-1) + prefix_transform = prefix_cls(opts) prefix_transform.warm_up() # 2. Init second transform in the pipe filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"] - opt = Namespace(src_seq_length=4, tgt_seq_length=4) - filter_transform = filter_cls(opt) + opts = Namespace(src_seq_length=4, tgt_seq_length=4) + filter_transform = filter_cls(opts) # 3. Sequential combine them into a transform pipe transform_pipe = TransformPipe.build_from([prefix_transform, filter_transform]) ex = { @@ -110,8 +110,8 @@ def test_prefix(self): tgt_prefix: "⦅_pf_tgt⦆" """ ) - opt = Namespace(data=corpora, seed=-1) - prefix_transform = prefix_cls(opt) + opts = Namespace(tasks=corpora, seed=-1) + prefix_transform = prefix_cls(opts) prefix_transform.warm_up() self.assertIn("trainset", prefix_transform.prefix_dict) @@ -128,8 +128,8 @@ def test_prefix(self): def test_filter_too_long(self): filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"] - opt = Namespace(src_seq_length=100, tgt_seq_length=100) - filter_transform = filter_cls(opt) + opts = Namespace(src_seq_length=100, tgt_seq_length=100) + filter_transform = filter_cls(opts) # filter_transform.warm_up() ex_in = { "src": ["Hello", "world", "."], @@ -162,9 +162,9 @@ def setUpClass(cls): def test_bpe(self): bpe_cls = get_transforms_cls(["bpe"])["bpe"] - opt = Namespace(**self.base_opts) - bpe_cls._validate_options(opt) - bpe_transform = bpe_cls(opt) + opts = Namespace(**self.base_opts) + bpe_cls._validate_options(opts) + bpe_transform = bpe_cls(opts) bpe_transform.warm_up() ex = { "src": ["Hello", "world", "."], @@ -212,9 +212,9 @@ def test_sentencepiece(self): base_opt = copy.copy(self.base_opts) base_opt["src_subword_model"] = "data/sample.sp.model" base_opt["tgt_subword_model"] = "data/sample.sp.model" - opt = Namespace(**base_opt) - sp_cls._validate_options(opt) - sp_transform = sp_cls(opt) + opts = Namespace(**base_opt) + sp_cls._validate_options(opts) + sp_transform = sp_cls(opts) sp_transform.warm_up() ex = { "src": ["Hello", "world", "."], @@ -245,9 +245,9 @@ def test_pyonmttok_bpe(self): onmt_args = "{'mode': 'space', 'joiner_annotate': True}" base_opt["src_onmttok_kwargs"] = onmt_args base_opt["tgt_onmttok_kwargs"] = onmt_args - opt = Namespace(**base_opt) - onmttok_cls._validate_options(opt) - onmttok_transform = onmttok_cls(opt) + opts = Namespace(**base_opt) + onmttok_cls._validate_options(opts) + onmttok_transform = onmttok_cls(opts) onmttok_transform.warm_up() ex = { "src": ["Hello", "world", "."], @@ -270,9 +270,9 @@ def test_pyonmttok_sp(self): onmt_args = "{'mode': 'none', 'spacer_annotate': True}" base_opt["src_onmttok_kwargs"] = onmt_args base_opt["tgt_onmttok_kwargs"] = onmt_args - opt = Namespace(**base_opt) - onmttok_cls._validate_options(opt) - onmttok_transform = onmttok_cls(opt) + opts = Namespace(**base_opt) + onmttok_cls._validate_options(opts) + onmttok_transform = onmttok_cls(opts) onmttok_transform.warm_up() ex = { "src": ["Hello", "world", "."], @@ -289,8 +289,8 @@ def test_pyonmttok_sp(self): class TestSamplingTransform(unittest.TestCase): def test_tokendrop(self): tokendrop_cls = get_transforms_cls(["tokendrop"])["tokendrop"] - opt = Namespace(seed=3434, tokendrop_temperature=0.1) - tokendrop_transform = tokendrop_cls(opt) + opts = Namespace(seed=3434, tokendrop_temperature=0.1) + tokendrop_transform = tokendrop_cls(opts) tokendrop_transform.warm_up() ex = { "src": ["Hello", ",", "world", "."], @@ -305,8 +305,8 @@ def test_tokendrop(self): def test_tokenmask(self): tokenmask_cls = get_transforms_cls(["tokenmask"])["tokenmask"] - opt = Namespace(seed=3434, tokenmask_temperature=0.1) - tokenmask_transform = tokenmask_cls(opt) + opts = Namespace(seed=3434, tokenmask_temperature=0.1) + tokenmask_transform = tokenmask_cls(opts) tokenmask_transform.warm_up() ex = { "src": ["Hello", ",", "world", "."], @@ -321,8 +321,8 @@ def test_tokenmask(self): def test_switchout(self): switchout_cls = get_transforms_cls(["switchout"])["switchout"] - opt = Namespace(seed=3434, switchout_temperature=0.1) - switchout_transform = switchout_cls(opt) + opts = Namespace(seed=3434, switchout_temperature=0.1) + switchout_transform = switchout_cls(opts) with self.assertRaises(ValueError): # require vocabs to warm_up switchout_transform.warm_up(vocabs=None) @@ -518,16 +518,16 @@ def test_span_infilling(self): def test_vocab_required_transform(self): transforms_cls = get_transforms_cls(["denoising"]) - opt = Namespace(random_ratio=1, denoising_objective='mass') + opts = Namespace(random_ratio=1, denoising_objective='mass') with self.assertRaises(ValueError): - make_transforms(opt, transforms_cls, vocabs=None, task=None) + make_transforms(opts, transforms_cls, vocabs=None, task=None) class TestFeaturesTransform(unittest.TestCase): def test_inferfeats(self): inferfeats_cls = get_transforms_cls(["inferfeats"])["inferfeats"] - opt = Namespace(reversible_tokenization="joiner", prior_tokenization=False) - inferfeats_transform = inferfeats_cls(opt) + opts = Namespace(reversible_tokenization="joiner", prior_tokenization=False) + inferfeats_transform = inferfeats_cls(opts) ex_in = { "src": [ diff --git a/onmt/tests/test_translation_server.py b/mammoth/tests/test_translation_server.py similarity index 85% rename from onmt/tests/test_translation_server.py rename to mammoth/tests/test_translation_server.py index 07489a25..c639f3b5 100644 --- a/onmt/tests/test_translation_server.py +++ b/mammoth/tests/test_translation_server.py @@ -1,12 +1,12 @@ import unittest -from onmt.translate.translation_server import ServerModel, TranslationServer +from mammoth.translate.translation_server import ServerModel, TranslationServer import os from textwrap import dedent import torch -from onmt.translate.translator import Translator +from mammoth.translate.translator import Translator TEST_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -16,9 +16,9 @@ class TestServerModel(unittest.TestCase): @unittest.skip('Broken in FoTraNMT') # FIXME def test_deferred_loading_model_and_unload(self): model_id = 0 - opt = {"models": ["test_model.pt"]} + opts = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opt, model_id, model_root=model_root, load=False) + sm = ServerModel(opts, model_id, model_root=model_root, load=False) self.assertFalse(sm.loaded) sm.load() self.assertTrue(sm.loaded) @@ -29,9 +29,9 @@ def test_deferred_loading_model_and_unload(self): @unittest.skip('Broken in FoTraNMT') # FIXME def test_load_model_on_init_and_unload(self): model_id = 0 - opt = {"models": ["test_model.pt"]} + opts = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opt, model_id, model_root=model_root, load=True) + sm = ServerModel(opts, model_id, model_root=model_root, load=True) self.assertTrue(sm.loaded) self.assertIsInstance(sm.translator, Translator) sm.unload() @@ -40,18 +40,18 @@ def test_load_model_on_init_and_unload(self): @unittest.skip('Broken in FoTraNMT') # FIXME def test_tokenizing_with_no_tokenizer_fails(self): model_id = 0 - opt = {"models": ["test_model.pt"]} + opts = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opt, model_id, model_root=model_root, load=True) + sm = ServerModel(opts, model_id, model_root=model_root, load=True) with self.assertRaises(ValueError): sm.tokenize("hello world") @unittest.skip('Broken in FoTraNMT') # FIXME def test_detokenizing_with_no_tokenizer_fails(self): model_id = 0 - opt = {"models": ["test_model.pt"]} + opts = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opt, model_id, model_root=model_root, load=True) + sm = ServerModel(opts, model_id, model_root=model_root, load=True) with self.assertRaises(ValueError): sm.detokenize("hello world") @@ -60,9 +60,9 @@ def test_detokenizing_with_no_tokenizer_fails(self): def test_moving_to_gpu_and_back(self): torch.cuda.set_device(torch.device("cuda", 0)) model_id = 0 - opt = {"models": ["test_model.pt"]} + opts = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opt, model_id, model_root=model_root, load=True) + sm = ServerModel(opts, model_id, model_root=model_root, load=True) for p in sm.translator.model.parameters(): self.assertEqual(p.device.type, "cpu") sm.to_gpu() @@ -76,9 +76,9 @@ def test_moving_to_gpu_and_back(self): def test_initialize_on_gpu_and_move_back(self): torch.cuda.set_device(torch.device("cuda", 0)) model_id = 0 - opt = {"models": ["test_model.pt"], "gpu": 0} + opts = {"models": ["test_model.pt"], "gpu": 0} model_root = TEST_DIR - sm = ServerModel(opt, model_id, model_root=model_root, load=True) + sm = ServerModel(opts, model_id, model_root=model_root, load=True) for p in sm.translator.model.parameters(): self.assertEqual(p.device.type, "cuda") self.assertEqual(p.device.index, 0) @@ -95,9 +95,9 @@ def test_initialize_on_gpu_and_move_back(self): def test_initialize_on_nonzero_gpu_and_back(self): torch.cuda.set_device(torch.device("cuda", 1)) model_id = 0 - opt = {"models": ["test_model.pt"], "gpu": 1} + opts = {"models": ["test_model.pt"], "gpu": 1} model_root = TEST_DIR - sm = ServerModel(opt, model_id, model_root=model_root, load=True) + sm = ServerModel(opts, model_id, model_root=model_root, load=True) for p in sm.translator.model.parameters(): self.assertEqual(p.device.type, "cuda") self.assertEqual(p.device.index, 1) @@ -112,9 +112,9 @@ def test_initialize_on_nonzero_gpu_and_back(self): @unittest.skip('Broken in FoTraNMT') # FIXME def test_run(self): model_id = 0 - opt = {"models": ["test_model.pt"]} + opts = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opt, model_id, model_root=model_root, load=True) + sm = ServerModel(opts, model_id, model_root=model_root, load=True) inp = [{"src": "hello how are you today"}, {"src": "good morning to you ."}] results, scores, n_best, time, aligns = sm.run(inp) self.assertIsInstance(results, list) @@ -160,7 +160,7 @@ def write(self, cfg): "timeout": -1, "on_timeout": "to_cpu", "load": false, - "opt": { + "opts": { "beam_size": 5 } } @@ -188,7 +188,7 @@ def test_start_without_initial_loading(self): "timeout": -1, "on_timeout": "to_cpu", "load": true, - "opt": { + "opts": { "beam_size": 5 } } @@ -217,7 +217,7 @@ def test_start_with_initial_loading(self): "timeout": -1, "on_timeout": "to_cpu", "load": true, - "opt": { + "opts": { "beam_size": 5 } }, @@ -227,7 +227,7 @@ def test_start_with_initial_loading(self): "timeout": -1, "on_timeout": "to_cpu", "load": false, - "opt": { + "opts": { "beam_size": 5 } } diff --git a/onmt/tests/test_translator.py b/mammoth/tests/test_translator.py similarity index 96% rename from onmt/tests/test_translator.py rename to mammoth/tests/test_translator.py index 78ffe60b..8107e2e6 100644 --- a/onmt/tests/test_translator.py +++ b/mammoth/tests/test_translator.py @@ -1,5 +1,5 @@ import unittest -from onmt.translate import GeneratorLM +from mammoth.translate import GeneratorLM import torch diff --git a/onmt/tests/utils_for_tests.py b/mammoth/tests/utils_for_tests.py similarity index 100% rename from onmt/tests/utils_for_tests.py rename to mammoth/tests/utils_for_tests.py diff --git a/onmt/train_single.py b/mammoth/train_single.py similarity index 74% rename from onmt/train_single.py rename to mammoth/train_single.py index b396c5f4..f036f0d9 100644 --- a/onmt/train_single.py +++ b/mammoth/train_single.py @@ -3,53 +3,54 @@ import torch import time -from onmt.model_builder import build_model -from onmt.utils.optimizers import Optimizer -from onmt.utils.misc import set_random_seed -from onmt.trainer import build_trainer -from onmt.models import build_model_saver -from onmt.utils.logging import init_logger, logger -from onmt.utils.parse import ArgumentParser +from mammoth.model_builder import build_model +from mammoth.utils.optimizers import Optimizer +from mammoth.utils.misc import set_random_seed +from mammoth.trainer import build_trainer +from mammoth.models import build_model_saver +from mammoth.utils.logging import init_logger, logger +from mammoth.utils.parse import ArgumentParser -from onmt.utils.distributed import broadcast_tensors -from onmt.inputters import DynamicDatasetIter -from onmt.transforms import get_transforms_cls +from mammoth.distributed import broadcast_tensors +from mammoth.inputters import DynamicDatasetIter +from mammoth.transforms import get_transforms_cls -def configure_process(opt, device_id): +def configure_process(opts, device_id): logger.info("logger set device {} ".format(device_id)) if device_id >= 0: torch.cuda.set_device(device_id) - set_random_seed(opt.seed, device_id >= 0) + set_random_seed(opts.seed, device_id >= 0) -def _get_model_opts(opt, checkpoint=None): - """Get `model_opt` to build model, may load from `checkpoint` if any.""" +def _get_model_opts(opts, checkpoint=None): + """Get `model_opts` to build model, may load from `checkpoint` if any.""" if checkpoint is not None: - model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) - ArgumentParser.update_model_opts(model_opt) - ArgumentParser.validate_model_opts(model_opt) - if opt.tensorboard_log_dir == model_opt.tensorboard_log_dir and hasattr(model_opt, 'tensorboard_log_dir_dated'): + model_opts = ArgumentParser.ckpt_model_opts(checkpoint["opts"]) + ArgumentParser.update_model_opts(model_opts) + ArgumentParser.validate_model_opts(model_opts) + if opts.tensorboard_log_dir == model_opts.tensorboard_log_dir and \ + hasattr(model_opts, 'tensorboard_log_dir_dated'): # ensure tensorboard output is written in the directory # of previous checkpoints - opt.tensorboard_log_dir_dated = model_opt.tensorboard_log_dir_dated + opts.tensorboard_log_dir_dated = model_opts.tensorboard_log_dir_dated # Override checkpoint's update_embeddings as it defaults to false - model_opt.update_vocab = opt.update_vocab + model_opts.update_vocab = opts.update_vocab else: - model_opt = opt - return model_opt + model_opts = opts + return model_opts -def _build_valid_iter(opt, vocabs_dict, transforms_cls, task_queue_manager): +def _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager): """Build iterator used for validation.""" - if not any(opt.data[corpus_id].get('path_valid_src', False) for corpus_id in opt.data.keys()): + if not any(opts.tasks[corpus_id].get('path_valid_src', False) for corpus_id in opts.tasks.keys()): return None logger.info("creating validation iterator") valid_iter = DynamicDatasetIter.from_opts( task_queue_manager=task_queue_manager, transforms_cls=transforms_cls, vocabs_dict=vocabs_dict, - opts=opt, + opts=opts, is_train=False, ) return valid_iter @@ -107,7 +108,7 @@ def init_distributed(model, task_queue_manager): def main( - opt, + opts, vocabs_dict, device_context, error_queue=None, @@ -116,26 +117,26 @@ def main( task_queue_manager=None, ): """Start training on `device_id`.""" - # NOTE: It's important that ``opt`` has been validated and updated + # NOTE: It's important that ``opts`` has been validated and updated # at this point. # N.B: task_queue_manager is already local - init_logger(opt.log_file, gpu_id=device_context.id) + init_logger(opts.log_file, gpu_id=device_context.id) if device_context.is_distributed(): sleep_s = device_context.local_rank * 3 logger.warning(f'sleeping {sleep_s}s to alleviate ROCm deadlock') time.sleep(sleep_s) - configure_process(opt, device_context.local_rank) + configure_process(opts, device_context.local_rank) gpu_rank_t = torch.distributed.get_rank() logger.info("RANK GPU FROM TORCH %s", str(gpu_rank_t)) - transforms_cls = get_transforms_cls(opt._all_transform) + transforms_cls = get_transforms_cls(opts._all_transform) checkpoint = None - model_opt = _get_model_opts(opt, checkpoint=checkpoint) + model_opts = _get_model_opts(opts, checkpoint=checkpoint) # Build model. - model, generators_md = build_model(model_opt, opt, vocabs_dict, task_queue_manager, checkpoint) + model, generators_md = build_model(model_opts, opts, vocabs_dict, task_queue_manager, checkpoint) logger.info("{} - Init model".format(device_context.id)) if device_context.is_distributed(): @@ -149,19 +150,19 @@ def main( # Build optimizer. logger.info("{} - Build optimizer".format(device_context.id)) - optim = Optimizer.from_opt( + optim = Optimizer.from_opts( model, - opt, + opts, task_queue_manager=task_queue_manager, checkpoint=checkpoint, ) # Build model saver - model_saver = build_model_saver(model_opt, opt, model, vocabs_dict, optim, device_context) + model_saver = build_model_saver(model_opts, opts, model, vocabs_dict, optim, device_context) logger.info("{} - Build trainer".format(device_context.id)) trainer = build_trainer( - opt, + opts, device_context, model, vocabs_dict, @@ -177,7 +178,7 @@ def main( task_queue_manager=task_queue_manager, transforms_cls=transforms_cls, vocabs_dict=vocabs_dict, - opts=opt, + opts=opts, is_train=True, ) # TODO: check that IterOnDevice is unnecessary here; corpora should be already on device @@ -198,15 +199,15 @@ def _train_iter(): train_iter = _train_iter() # train_iter = iter_on_device(train_iter, device_context) logger.info("Device {} - Valid iter".format(device_context.id)) - valid_iter = _build_valid_iter(opt, vocabs_dict, transforms_cls, task_queue_manager) + valid_iter = _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager) - if len(opt.gpu_ranks): + if len(opts.gpu_ranks): if device_context.is_master(): - logger.info('Starting training on GPU: %s' % opt.gpu_ranks) + logger.info('Starting training on GPU: %s' % opts.gpu_ranks) else: logger.info('Starting training on CPU, could be very slow') - train_steps = opt.train_steps - if opt.single_pass and train_steps > 0: + train_steps = opts.train_steps + if opts.single_pass and train_steps > 0: if device_context.is_master(): logger.warning("Option single_pass is enabled, ignoring train_steps.") train_steps = 0 @@ -214,9 +215,9 @@ def _train_iter(): trainer.train( train_iter, train_steps, - save_checkpoint_steps=opt.save_checkpoint_steps, + save_checkpoint_steps=opts.save_checkpoint_steps, valid_iter=valid_iter, - valid_steps=opt.valid_steps, + valid_steps=opts.valid_steps, device_context=device_context, ) diff --git a/onmt/trainer.py b/mammoth/trainer.py similarity index 86% rename from onmt/trainer.py rename to mammoth/trainer.py index 0b76a3a3..4f6527e9 100644 --- a/onmt/trainer.py +++ b/mammoth/trainer.py @@ -10,14 +10,14 @@ """ -import onmt.utils +import mammoth.distributed import torch import torch.distributed import torch.nn as nn import traceback from itertools import islice -from onmt.utils.logging import logger +from mammoth.utils.logging import logger def iter_on_device(iterator, device_context): @@ -30,7 +30,7 @@ def iter_on_device(iterator, device_context): def build_trainer( - opt, + opts, device_context, model, vocabs_dict, @@ -40,16 +40,16 @@ def build_trainer( generators_md=None, ): """ - Simplify `Trainer` creation based on user `opt`s* + Simplify `Trainer` creation based on user `opts`s* Args: - opt (:obj:`Namespace`): user options (usually from argument parsing) - model (:obj:`onmt.models.NMTModel`): the model to train + opts (:obj:`Namespace`): user options (usually from argument parsing) + model (:obj:`mammoth.models.NMTModel`): the model to train vocabs_dict (dict): dict of vocabs - optim (:obj:`onmt.utils.Optimizer`): optimizer used during training + optim (:obj:`mammoth.utils.Optimizer`): optimizer used during training data_type (str): string describing the type of data e.g. "text" - model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object + model_saver(:obj:`mammoth.models.ModelSaverBase`): the utility object used to save the model """ @@ -61,32 +61,32 @@ def build_trainer( generator = generators_md[f'generator_{lang}'] train_loss_md.add_module( f'trainloss{lang}', - onmt.utils.loss.build_loss_compute(model, tgt_vocab, opt, train=True, generator=generator), + mammoth.utils.loss.build_loss_compute(model, tgt_vocab, opts, train=True, generator=generator), ) valid_loss_md.add_module( f'valloss{lang}', - onmt.utils.loss.build_loss_compute(model, tgt_vocab, opt, train=False, generator=generator), + mammoth.utils.loss.build_loss_compute(model, tgt_vocab, opts, train=False, generator=generator), ) - trunc_size = opt.truncated_decoder # Badly named... - shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0 - norm_method = opt.normalization - accum_count = opt.accum_count - accum_steps = opt.accum_steps - average_decay = opt.average_decay - average_every = opt.average_every - dropout = opt.dropout - dropout_steps = opt.dropout_steps - gpu_verbose_level = opt.gpu_verbose_level + trunc_size = opts.truncated_decoder # Badly named... + shard_size = opts.max_generator_batches if opts.model_dtype == 'fp32' else 0 + norm_method = opts.normalization + accum_count = opts.accum_count + accum_steps = opts.accum_steps + average_decay = opts.average_decay + average_every = opts.average_every + dropout = opts.dropout + dropout_steps = opts.dropout_steps + gpu_verbose_level = opts.gpu_verbose_level earlystopper = ( - onmt.utils.EarlyStopping(opt.early_stopping, scorers=onmt.utils.scorers_from_opts(opt)) - if opt.early_stopping > 0 + mammoth.utils.EarlyStopping(opts.early_stopping, scorers=mammoth.utils.scorers_from_opts(opts)) + if opts.early_stopping > 0 else None ) - report_manager = onmt.utils.build_report_manager(opt, device_context.node_rank, device_context.local_rank) - trainer = onmt.Trainer( + report_manager = mammoth.utils.build_report_manager(opts, device_context.node_rank, device_context.local_rank) + trainer = mammoth.Trainer( model, train_loss_md, valid_loss_md, @@ -99,16 +99,16 @@ def build_trainer( device_context=device_context, gpu_verbose_level=gpu_verbose_level, report_manager=report_manager, - with_align=True if opt.lambda_align > 0 else False, + with_align=True if opts.lambda_align > 0 else False, model_saver=model_saver, average_decay=average_decay, average_every=average_every, - model_dtype=opt.model_dtype, + model_dtype=opts.model_dtype, earlystopper=earlystopper, dropout=dropout, dropout_steps=dropout_steps, task_queue_manager=task_queue_manager, - report_stats_from_parameters=opt.report_stats_from_parameters, + report_stats_from_parameters=opts.report_stats_from_parameters, ) return trainer @@ -118,13 +118,13 @@ class Trainer(object): Class that controls the training process. Args: - model(:py:class:`onmt.models.model.NMTModel`): translation model + model(:py:class:`mammoth.models.model.NMTModel`): translation model to train - train_loss(:obj:`onmt.utils.loss.LossComputeBase`): + train_loss(:obj:`mammoth.utils.loss.LossComputeBase`): training loss computation - valid_loss(:obj:`onmt.utils.loss.LossComputeBase`): + valid_loss(:obj:`mammoth.utils.loss.LossComputeBase`): training loss computation - optim(:obj:`onmt.utils.optimizers.Optimizer`): + optim(:obj:`mammoth.utils.optimizers.Optimizer`): the optimizer responsible for update trunc_size(int): length of truncated back propagation through time shard_size(int): compute loss in shards of this size for efficiency @@ -132,9 +132,9 @@ class Trainer(object): norm_method(string): normalization methods: [sents|tokens] accum_count(list): accumulate gradients this many times. accum_steps(list): steps for accum gradients changes. - report_manager(:obj:`onmt.utils.ReportMgrBase`): + report_manager(:obj:`mammoth.utils.ReportMgrBase`): the object that creates reports, or None - model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is + model_saver(:obj:`mammoth.models.ModelSaverBase`): the saver is used to save a checkpoint. Thus nothing will be saved if this parameter is None """ @@ -260,8 +260,8 @@ def train( else: logger.info('Start training loop and validate every %d steps...', valid_steps) - total_stats = onmt.utils.Statistics() - report_stats = onmt.utils.Statistics() + total_stats = mammoth.utils.Statistics() + report_stats = mammoth.utils.Statistics() self._start_report_manager(start_time=total_stats.start_time) self.optim.zero_grad() @@ -289,7 +289,7 @@ def train( in self.model.encoder.get_submodule(layer_stack_index, encoder_id).named_parameters() if 'embeddings' not in name and 'adapter' not in name ] - onmt.utils.distributed.only_ready_reduce_and_rescale_grads(params, group=group) + mammoth.distributed.only_ready_reduce_and_rescale_grads(params, group=group) for (layer_stack_index, decoder_id), (_, group) in self.my_decoder_groups.items(): params = [ @@ -297,17 +297,17 @@ def train( in self.model.decoder.get_submodule(layer_stack_index, decoder_id).named_parameters() if 'embeddings' not in name and 'adapter' not in name ] - onmt.utils.distributed.only_ready_reduce_and_rescale_grads(params, group=group) + mammoth.distributed.only_ready_reduce_and_rescale_grads(params, group=group) for (src_lang,), (_, group) in self.my_src_emb_groups.items(): embs = self.model.encoder.embeddings[f'embeddings_{src_lang}'] - onmt.utils.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) + mammoth.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) for (tgt_lang,), (_, group) in self.my_tgt_emb_groups.items(): embs = self.model.decoder.embeddings[f'embeddings_{tgt_lang}'] - onmt.utils.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) + mammoth.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) - onmt.utils.distributed.only_ready_reduce_and_rescale_grads( + mammoth.distributed.only_ready_reduce_and_rescale_grads( self.model.generator[f'generator_{tgt_lang}'].named_parameters(), group=group ) @@ -316,18 +316,18 @@ def train( adapter = self.model.encoder.get_submodule(layer_stack_index, encoder_id).get_adapter( adapter_group, sub_id ) - onmt.utils.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) + mammoth.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) for adapter_id, (_, group) in self.my_decoder_adapter_groups.items(): layer_stack_index, decoder_id, adapter_group, sub_id = adapter_id adapter = self.model.decoder.get_submodule(layer_stack_index, decoder_id).get_adapter( adapter_group, sub_id ) - onmt.utils.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) + mammoth.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) # a group is not specified: reduce across all devices if device_context.is_distributed(): - onmt.utils.distributed.only_ready_reduce_and_rescale_grads( + mammoth.distributed.only_ready_reduce_and_rescale_grads( self.model.attention_bridge.named_parameters() ) @@ -420,11 +420,11 @@ def validate(self, valid_iter, moving_average=None, task=None): # Tasks need not define validation paths: hence, a device need not contain # any validation path. This would cause statistics equals to 0 word seen, # which would then cause a zero devision when normalizing PPL per words. - stats = None # onmt.utils.Statistics() + stats = None # mammoth.utils.Statistics() for batch, metadata, _ in valid_iter: if stats is None: - stats = onmt.utils.Statistics() + stats = mammoth.utils.Statistics() src, src_lengths = batch.src if isinstance(batch.src, tuple) else (batch.src, None) tgt = batch.tgt @@ -542,14 +542,14 @@ def _maybe_gather_stats(self, stat): Gather statistics in multi-processes cases Args: - stat(:obj:onmt.utils.Statistics): a Statistics object to gather + stat(:obj:mammoth.utils.Statistics): a Statistics object to gather or None (it returns None in this case) Returns: stat: the updated (or unchanged) stat object """ if stat is not None and self.device_context.is_distributed(): - return onmt.utils.Statistics.all_gather_stats(stat) + return mammoth.utils.Statistics.all_gather_stats(stat) return stat def _maybe_update_stats_from_parameters(self, report_stats, named_parameters): @@ -559,7 +559,7 @@ def _maybe_update_stats_from_parameters(self, report_stats, named_parameters): def _maybe_report_training(self, step, num_steps, learning_rate, report_stats): """ Simple function to report training stats (if report_manager is set) - see `onmt.utils.ReportManagerBase.report_training` for doc + see `mammoth.utils.ReportManagerBase.report_training` for doc """ if self.report_manager is not None: return self.report_manager.report_training( @@ -574,7 +574,7 @@ def _maybe_report_training(self, step, num_steps, learning_rate, report_stats): def _report_step(self, learning_rate, step, train_stats=None, valid_stats=None): """ Simple function to report stats (if report_manager is set) - see `onmt.utils.ReportManagerBase.report_step` for doc + see `mammoth.utils.ReportManagerBase.report_step` for doc """ if self.report_manager is not None: return self.report_manager.report_step( diff --git a/onmt/transforms/__init__.py b/mammoth/transforms/__init__.py similarity index 95% rename from onmt/transforms/__init__.py rename to mammoth/transforms/__init__.py index 673f8383..b585e216 100644 --- a/onmt/transforms/__init__.py +++ b/mammoth/transforms/__init__.py @@ -49,4 +49,4 @@ def register_transfrom_cls(cls): path = os.path.join(transform_dir, file) if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): file_name = file[: file.find('.py')] if file.endswith('.py') else file - module = importlib.import_module('onmt.transforms.' + file_name) + module = importlib.import_module('mammoth.transforms.' + file_name) diff --git a/onmt/transforms/denoising.py b/mammoth/transforms/denoising.py similarity index 99% rename from onmt/transforms/denoising.py rename to mammoth/transforms/denoising.py index 36fa8c44..f1fc9693 100644 --- a/onmt/transforms/denoising.py +++ b/mammoth/transforms/denoising.py @@ -4,8 +4,8 @@ import torch from typing import Sequence, Callable -from onmt.constants import DefaultTokens, SubwordMarker -from onmt.transforms import register_transform +from mammoth.constants import DefaultTokens, SubwordMarker +from mammoth.transforms import register_transform from .transform import Transform diff --git a/onmt/transforms/features.py b/mammoth/transforms/features.py similarity index 94% rename from onmt/transforms/features.py rename to mammoth/transforms/features.py index a6fd06c4..762cdcd4 100644 --- a/onmt/transforms/features.py +++ b/mammoth/transforms/features.py @@ -1,7 +1,7 @@ -from onmt.utils.logging import logger -from onmt.transforms import register_transform +from mammoth.utils.logging import logger +from mammoth.transforms import register_transform from .transform import Transform -from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer +from mammoth.utils.alignment import subword_map_by_joiner, subword_map_by_spacer import re from collections import defaultdict diff --git a/onmt/transforms/misc.py b/mammoth/transforms/misc.py similarity index 96% rename from onmt/transforms/misc.py rename to mammoth/transforms/misc.py index a7b8e1a0..b8c1e8b1 100644 --- a/onmt/transforms/misc.py +++ b/mammoth/transforms/misc.py @@ -1,5 +1,5 @@ -from onmt.utils.logging import logger -from onmt.transforms import register_transform +from mammoth.utils.logging import logger +from mammoth.transforms import register_transform from .transform import Transform, ObservableStats @@ -72,7 +72,7 @@ def _get_prefix(corpus): def get_prefix_dict(cls, opts): """Get all needed prefix correspond to corpus in `opts`.""" prefix_dict = {} - for c_name, corpus in opts.data.items(): + for c_name, corpus in opts.tasks.items(): prefix = cls._get_prefix(corpus) if prefix is not None: logger.info(f"Get prefix for {c_name}: {prefix}") diff --git a/onmt/transforms/sampling.py b/mammoth/transforms/sampling.py similarity index 98% rename from onmt/transforms/sampling.py rename to mammoth/transforms/sampling.py index c0fadea9..e2b55182 100644 --- a/onmt/transforms/sampling.py +++ b/mammoth/transforms/sampling.py @@ -1,8 +1,8 @@ """Transforms relate to hamming distance sampling.""" import random import numpy as np -from onmt.constants import DefaultTokens -from onmt.transforms import register_transform +from mammoth.constants import DefaultTokens +from mammoth.transforms import register_transform from .transform import Transform, ObservableStats diff --git a/onmt/transforms/tokenize.py b/mammoth/transforms/tokenize.py similarity index 99% rename from onmt/transforms/tokenize.py rename to mammoth/transforms/tokenize.py index bf4470d3..5c6e283a 100644 --- a/onmt/transforms/tokenize.py +++ b/mammoth/transforms/tokenize.py @@ -1,6 +1,6 @@ """Transforms relate to tokenization/subword.""" -from onmt.utils.logging import logger -from onmt.transforms import register_transform +from mammoth.utils.logging import logger +from mammoth.transforms import register_transform from .transform import Transform, ObservableStats diff --git a/onmt/transforms/transform.py b/mammoth/transforms/transform.py similarity index 99% rename from onmt/transforms/transform.py rename to mammoth/transforms/transform.py index 6238d3ae..e553f6a2 100644 --- a/onmt/transforms/transform.py +++ b/mammoth/transforms/transform.py @@ -1,7 +1,7 @@ """Base Transform class and relate utils.""" import torch -from onmt.utils.logging import logger -from onmt.utils.misc import check_path +from mammoth.utils.logging import logger +from mammoth.utils.misc import check_path class Transform(object): diff --git a/mammoth/translate/__init__.py b/mammoth/translate/__init__.py new file mode 100644 index 00000000..a48ea841 --- /dev/null +++ b/mammoth/translate/__init__.py @@ -0,0 +1,25 @@ +""" Modules for translation """ +from mammoth.translate.translator import Translator, GeneratorLM +from mammoth.translate.translation import Translation, TranslationBuilder +from mammoth.translate.beam_search import BeamSearch, GNMTGlobalScorer +from mammoth.translate.beam_search import BeamSearchLM +from mammoth.translate.decode_strategy import DecodeStrategy +from mammoth.translate.greedy_search import GreedySearch, GreedySearchLM +from mammoth.translate.penalties import PenaltyBuilder +from mammoth.translate.translation_server import TranslationServer, ServerModelError + +__all__ = [ + 'Translator', + 'Translation', + 'BeamSearch', + 'GNMTGlobalScorer', + 'TranslationBuilder', + 'PenaltyBuilder', + 'TranslationServer', + 'ServerModelError', + "DecodeStrategy", + "GreedySearch", + "GreedySearchLM", + "BeamSearchLM", + "GeneratorLM", +] diff --git a/onmt/translate/beam_search.py b/mammoth/translate/beam_search.py similarity index 98% rename from onmt/translate/beam_search.py rename to mammoth/translate/beam_search.py index cb60c298..c5741367 100644 --- a/onmt/translate/beam_search.py +++ b/mammoth/translate/beam_search.py @@ -1,6 +1,6 @@ import torch -from onmt.translate import penalties -from onmt.translate.decode_strategy import DecodeStrategy +from mammoth.translate import penalties +from mammoth.translate.decode_strategy import DecodeStrategy import warnings @@ -22,7 +22,7 @@ class BeamSearchBase(DecodeStrategy): unk (int): See base. n_best (int): Don't stop until at least this many beams have reached EOS. - global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. + global_scorer (mammoth.translate.GNMTGlobalScorer): Scorer instance. min_length (int): See base. max_length (int): See base. return_attention (bool): See base. @@ -410,8 +410,8 @@ class GNMTGlobalScorer(object): """ @classmethod - def from_opt(cls, opt): - return cls(opt.alpha, opt.beta, opt.length_penalty, opt.coverage_penalty) + def from_opts(cls, opts): + return cls(opts.alpha, opts.beta, opts.length_penalty, opts.coverage_penalty) def __init__(self, alpha, beta, length_penalty, coverage_penalty): self._validate(alpha, beta, length_penalty, coverage_penalty) diff --git a/onmt/translate/decode_strategy.py b/mammoth/translate/decode_strategy.py similarity index 99% rename from onmt/translate/decode_strategy.py rename to mammoth/translate/decode_strategy.py index 0fd86906..cabf7539 100644 --- a/onmt/translate/decode_strategy.py +++ b/mammoth/translate/decode_strategy.py @@ -1,7 +1,7 @@ import torch from copy import deepcopy -from onmt.utils.misc import tile +from mammoth.utils.misc import tile class DecodeStrategy(object): diff --git a/onmt/translate/greedy_search.py b/mammoth/translate/greedy_search.py similarity index 96% rename from onmt/translate/greedy_search.py rename to mammoth/translate/greedy_search.py index 1631d379..91251b32 100644 --- a/onmt/translate/greedy_search.py +++ b/mammoth/translate/greedy_search.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F -from onmt.translate.decode_strategy import DecodeStrategy +from mammoth.translate.decode_strategy import DecodeStrategy def sample_topp(logits, keep_topp): @@ -98,7 +98,7 @@ class GreedySearch(DecodeStrategy): eos (int): See base. unk (int): See base. batch_size (int): See base. - global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. + global_scorer (mammoth.translate.GNMTGlobalScorer): Scorer instance. min_length (int): See base. max_length (int): See base. ban_unk_token (Boolean): See base. @@ -107,11 +107,11 @@ class GreedySearch(DecodeStrategy): return_attention (bool): See base. max_length (int): See base. sampling_temp (float): See - :func:`~onmt.translate.greedy_search.sample_with_temperature()`. + :func:`~mammoth.translate.greedy_search.sample_with_temperature()`. keep_topk (int): See - :func:`~onmt.translate.greedy_search.sample_with_temperature()`. + :func:`~mammoth.translate.greedy_search.sample_with_temperature()`. keep_topp (float): See - :func:`~onmt.translate.greedy_search.sample_with_temperature()`. + :func:`~mammoth.translate.greedy_search.sample_with_temperature()`. beam_size (int): Number of beams to use. """ diff --git a/onmt/translate/penalties.py b/mammoth/translate/penalties.py similarity index 100% rename from onmt/translate/penalties.py rename to mammoth/translate/penalties.py diff --git a/onmt/translate/process_zh.py b/mammoth/translate/process_zh.py similarity index 100% rename from onmt/translate/process_zh.py rename to mammoth/translate/process_zh.py diff --git a/onmt/translate/translation.py b/mammoth/translate/translation.py similarity index 96% rename from onmt/translate/translation.py rename to mammoth/translate/translation.py index 8d2aebb9..3b985d00 100644 --- a/onmt/translate/translation.py +++ b/mammoth/translate/translation.py @@ -1,7 +1,7 @@ """ Translation main class """ import os -from onmt.constants import DefaultTokens -from onmt.utils.alignment import build_align_pharaoh +from mammoth.constants import DefaultTokens +from mammoth.utils.alignment import build_align_pharaoh # FIXME @@ -14,8 +14,8 @@ class TranslationBuilder(object): Problem in Neural Machine Translation" :cite:`Luong2015b` Args: - data (onmt.inputters.ParallelCorpus): Data. - vocabs (dict[str, onmt.inputters.Vocab]): data vocabs + data (mammoth.inputters.ParallelCorpus): Data. + vocabs (dict[str, mammoth.inputters.Vocab]): data vocabs n_best (int): number of translations produced replace_unk (bool): replace unknown words using attention has_tgt (bool): will the batch have gold targets diff --git a/onmt/translate/translation_server.py b/mammoth/translate/translation_server.py similarity index 89% rename from onmt/translate/translation_server.py rename to mammoth/translate/translation_server.py index 7184637d..be2e4043 100644 --- a/onmt/translate/translation_server.py +++ b/mammoth/translate/translation_server.py @@ -10,18 +10,18 @@ import traceback import importlib import torch -import onmt.opts +import mammoth.opts from itertools import islice, zip_longest from copy import deepcopy -from onmt.constants import DefaultTokens -from onmt.utils.logging import init_logger -from onmt.utils.misc import set_random_seed -from onmt.utils.misc import check_model_config -from onmt.utils.alignment import to_word_align -from onmt.utils.parse import ArgumentParser -from onmt.translate.translator import build_translator +from mammoth.constants import DefaultTokens +from mammoth.utils.logging import init_logger +from mammoth.utils.misc import set_random_seed +from mammoth.utils.misc import check_model_config +from mammoth.utils.alignment import to_word_align +from mammoth.utils.parse import ArgumentParser +from mammoth.translate.translator import build_translator def critical(func): @@ -78,7 +78,7 @@ class ServerModelError(Exception): class CTranslate2Translator(object): """ This class wraps the ctranslate2.Translator object to - reproduce the onmt.translate.translator API. + reproduce the mammoth.translate.translator API. """ def __init__(self, model_path, ct2_translator_args, ct2_translate_batch_args, target_prefix=False, preload=False): @@ -95,7 +95,7 @@ def __init__(self, model_path, ct2_translator_args, ct2_translate_batch_args, ta self.translator.unload_model(to_cpu=True) @staticmethod - def convert_onmt_to_ct2_opts(ct2_translator_args, ct2_translate_batch_args, opt): + def convert_onmt_to_ct2_opts(ct2_translator_args, ct2_translate_batch_args, opts): def setdefault_if_exists_must_match(obj, name, value): if name in obj: assert value == obj[name], ( @@ -115,18 +115,18 @@ def setdefault_if_exists_must_match(obj, name, value): ct2_translator_args.setdefault(name, value) onmt_for_translator = { - "device": "cuda" if opt.cuda else "cpu", - "device_index": opt.gpu if opt.cuda else 0, + "device": "cuda" if opts.cuda else "cpu", + "device_index": opts.gpu if opts.cuda else 0, } for name, value in onmt_for_translator.items(): setdefault_if_exists_must_match(ct2_translator_args, name, value) onmt_for_translate_batch_enforce = { - "beam_size": opt.beam_size, - "max_batch_size": opt.batch_size, - "num_hypotheses": opt.n_best, - "max_decoding_length": opt.max_length, - "min_decoding_length": opt.min_length, + "beam_size": opts.beam_size, + "max_batch_size": opts.batch_size, + "num_hypotheses": opts.n_best, + "max_decoding_length": opts.max_length, + "min_decoding_length": opts.min_length, } for name, value in onmt_for_translate_batch_enforce.items(): setdefault_if_exists_must_match(ct2_translate_batch_args, name, value) @@ -191,32 +191,32 @@ def start(self, config_file): } kwargs = {k: v for (k, v) in kwargs.items() if v is not None} model_id = conf.get("id", None) - opt = conf["opt"] - opt["models"] = conf["models"] - self.preload_model(opt, model_id=model_id, **kwargs) + opts = conf["opts"] + opts["models"] = conf["models"] + self.preload_model(opts, model_id=model_id, **kwargs) - def clone_model(self, model_id, opt, timeout=-1): + def clone_model(self, model_id, opts, timeout=-1): """Clone a model `model_id`. - Different options may be passed. If `opt` is None, it will use the + Different options may be passed. If `opts` is None, it will use the same set of options """ if model_id in self.models: - if opt is None: - opt = self.models[model_id].user_opt - opt["models"] = self.models[model_id].opt.models - return self.load_model(opt, timeout) + if opts is None: + opts = self.models[model_id].user_opt + opts["models"] = self.models[model_id].opts.models + return self.load_model(opts, timeout) else: raise ServerModelError("No such model '%s'" % str(model_id)) - def load_model(self, opt, model_id=None, **model_kwargs): + def load_model(self, opts, model_id=None, **model_kwargs): """Load a model given a set of options""" - model_id = self.preload_model(opt, model_id=model_id, **model_kwargs) + model_id = self.preload_model(opts, model_id=model_id, **model_kwargs) load_time = self.models[model_id].load_time return model_id, load_time - def preload_model(self, opt, model_id=None, **model_kwargs): + def preload_model(self, opts, model_id=None, **model_kwargs): """Preloading the model: updating internal datastructure It will effectively load the model if `load` is set @@ -230,7 +230,7 @@ def preload_model(self, opt, model_id=None, **model_kwargs): model_id += 1 self.next_id = model_id + 1 print("Pre-loading model %d" % model_id) - model = ServerModel(opt, model_id, **model_kwargs) + model = ServerModel(opts, model_id, **model_kwargs) self.models[model_id] = model return model_id @@ -274,7 +274,7 @@ class ServerModel(object): """Wrap a model with server functionality. Args: - opt (dict): Options for the Translator + opts (dict): Options for the Translator model_id (int): Model ID preprocess_opt (list): Options for preprocess processus or None tokenizer_opt (dict): Options for the tokenizer or None @@ -292,7 +292,7 @@ class ServerModel(object): def __init__( self, - opt, + opts, model_id, preprocess_opt=None, tokenizer_opt=None, @@ -307,7 +307,7 @@ def __init__( ct2_translate_batch_args=None, ): self.model_root = model_root - self.opt = self.parse_opt(opt) + self.opts = self.parse_opt(opts) self.custom_opt = custom_opt self.model_id = model_id @@ -322,20 +322,20 @@ def __init__( self.ct2_translate_batch_args = ct2_translate_batch_args self.unload_timer = None - self.user_opt = opt + self.user_opt = opts self.tokenizers = None - if len(self.opt.log_file) > 0: - log_file = os.path.join(model_root, self.opt.log_file) + if len(self.opts.log_file) > 0: + log_file = os.path.join(model_root, self.opts.log_file) else: log_file = None - self.logger = init_logger(log_file=log_file, log_file_level=self.opt.log_file_level, rotate=True) + self.logger = init_logger(log_file=log_file, log_file_level=self.opts.log_file_level, rotate=True) self.loading_lock = threading.Event() self.loading_lock.set() self.running_lock = threading.Semaphore(value=1) - set_random_seed(self.opt.seed, self.opt.cuda) + set_random_seed(self.opts.seed, self.opts.cuda) if self.preprocess_opt is not None: self.logger.info("Loading preprocessor") @@ -370,28 +370,28 @@ def __init__( self.load(preload=True) self.stop_unload_timer() - def parse_opt(self, opt): - """Parse the option set passed by the user using `onmt.opts` + def parse_opt(self, opts): + """Parse the option set passed by the user using `mammoth.opts` Args: - opt (dict): Options passed by the user + opts (dict): Options passed by the user Returns: - opt (argparse.Namespace): full set of options for the Translator + opts (argparse.Namespace): full set of options for the Translator """ prec_argv = sys.argv sys.argv = sys.argv[:1] parser = ArgumentParser() - onmt.opts.translate_opts(parser) + mammoth.opts.translate_opts(parser) - models = opt['models'] + models = opts['models'] if not isinstance(models, (list, tuple)): models = [models] - opt['models'] = [os.path.join(self.model_root, model) for model in models] - opt['src'] = "dummy_src" + opts['models'] = [os.path.join(self.model_root, model) for model in models] + opts['src'] = "dummy_src" - for (k, v) in opt.items(): + for (k, v) in opts.items(): if k == 'models': sys.argv += ['-model'] sys.argv += [str(model) for model in v] @@ -400,12 +400,12 @@ def parse_opt(self, opt): else: sys.argv += ['-%s' % k, str(v)] - opt = parser.parse_args() - ArgumentParser.validate_translate_opts(opt) - opt.cuda = opt.gpu > -1 + opts = parser.parse_args() + ArgumentParser.validate_translate_opts(opts) + opts.cuda = opts.gpu > -1 sys.argv = prec_argv - return opt + return opts @property def loaded(self): @@ -421,18 +421,18 @@ def load(self, preload=False): try: if self.ct2_model is not None: CTranslate2Translator.convert_onmt_to_ct2_opts( - self.ct2_translator_args, self.ct2_translate_batch_args, self.opt + self.ct2_translator_args, self.ct2_translate_batch_args, self.opts ) self.translator = CTranslate2Translator( self.ct2_model, ct2_translator_args=self.ct2_translator_args, ct2_translate_batch_args=self.ct2_translate_batch_args, - target_prefix=self.opt.tgt_prefix, + target_prefix=self.opts.tgt_prefix, preload=preload, ) else: self.translator = build_translator( - self.opt, report_score=False, out_file=codecs.open(os.devnull, "w", "utf-8") + self.opts, report_score=False, out_file=codecs.open(os.devnull, "w", "utf-8") ) except RuntimeError as e: raise ServerModelError("Runtime Error: %s" % str(e)) @@ -470,7 +470,7 @@ def run(self, inputs): if not self.loaded: self.load() timer.tick(name="load") - elif self.opt.cuda: + elif self.opts.cuda: self.to_gpu() timer.tick(name="to_gpu") @@ -517,14 +517,14 @@ def run(self, inputs): scores, predictions = self.translator.translate( texts_to_translate, tgt=texts_ref, - batch_size=len(texts_to_translate) if self.opt.batch_size == 0 else self.opt.batch_size, + batch_size=len(texts_to_translate) if self.opts.batch_size == 0 else self.opts.batch_size, ) except (RuntimeError, Exception) as e: err = "Error: %s" % str(e) self.logger.error(err) self.logger.error("repr(text_to_translate): " + repr(texts_to_translate)) self.logger.error("model: #%s" % self.model_id) - self.logger.error("model opt: " + str(self.opt.__dict__)) + self.logger.error("model opts: " + str(self.opts.__dict__)) self.logger.error(traceback.format_exc()) raise ServerModelError(err) @@ -541,7 +541,7 @@ def run(self, inputs): def flatten_list(_list): return sum(_list, []) - tiled_texts = [t for t in texts_to_translate for _ in range(self.opt.n_best)] + tiled_texts = [t for t in texts_to_translate for _ in range(self.opts.n_best)] results = flatten_list(predictions) def maybe_item(x): @@ -556,24 +556,24 @@ def maybe_item(x): # build back results with empty texts for i in empty_indices: - j = i * self.opt.n_best - results = results[:j] + [""] * self.opt.n_best + results[j:] - aligns = aligns[:j] + [None] * self.opt.n_best + aligns[j:] - scores = scores[:j] + [0] * self.opt.n_best + scores[j:] + j = i * self.opts.n_best + results = results[:j] + [""] * self.opts.n_best + results[j:] + aligns = aligns[:j] + [None] * self.opts.n_best + aligns[j:] + scores = scores[:j] + [0] * self.opts.n_best + scores[j:] rebuilt_segs, scores, aligns = self.rebuild_seg_packages( - all_preprocessed, results, scores, aligns, self.opt.n_best + all_preprocessed, results, scores, aligns, self.opts.n_best ) results = [self.maybe_postprocess(seg) for seg in rebuilt_segs] - head_spaces = [h for h in head_spaces for i in range(self.opt.n_best)] - tail_spaces = [h for h in tail_spaces for i in range(self.opt.n_best)] + head_spaces = [h for h in head_spaces for i in range(self.opts.n_best)] + tail_spaces = [h for h in tail_spaces for i in range(self.opts.n_best)] results = ["".join(items) for items in zip(head_spaces, results, tail_spaces)] self.logger.info("Translation Results: %d", len(results)) - return results, scores, self.opt.n_best, timer.times, aligns + return results, scores, self.opts.n_best, timer.times, aligns def rebuild_seg_packages(self, all_preprocessed, results, scores, aligns, n_best): """ @@ -618,7 +618,7 @@ def do_timeout(self): def unload(self): self.logger.info("Unloading model %d" % self.model_id) del self.translator - if self.opt.cuda: + if self.opts.cuda: torch.cuda.empty_cache() self.stop_unload_timer() self.unload_timer = None @@ -639,7 +639,7 @@ def to_dict(self): hide_opt = ["models", "src"] d = { "model_id": self.model_id, - "opt": {k: self.user_opt[k] for k in self.user_opt.keys() if k not in hide_opt}, + "opts": {k: self.user_opt[k] for k in self.user_opt.keys() if k not in hide_opt}, "models": self.user_opt["models"], "loaded": self.loaded, "timeout": self.timeout, @@ -655,7 +655,7 @@ def to_cpu(self): self.translator.to_cpu() else: self.translator.model.cpu() - if self.opt.cuda: + if self.opts.cuda: torch.cuda.empty_cache() def to_gpu(self): @@ -663,7 +663,7 @@ def to_gpu(self): if type(self.translator) == CTranslate2Translator: self.translator.to_gpu() else: - torch.cuda.set_device(self.opt.gpu) + torch.cuda.set_device(self.opts.gpu) self.translator.model.cuda() def maybe_preprocess(self, sequence): @@ -785,7 +785,7 @@ def maybe_detokenize_with_align(self, sequence, src, side='tgt'): sorted or None if no alignment in output. """ align = None - if self.opt.report_align: + if self.opts.report_align: # output contain alignment sequence, align = sequence.split(DefaultTokens.ALIGNMENT_SEPARATOR) if align != '': diff --git a/onmt/translate/translator.py b/mammoth/translate/translator.py similarity index 88% rename from onmt/translate/translator.py rename to mammoth/translate/translator.py index 44c9092c..8388ec22 100644 --- a/onmt/translate/translator.py +++ b/mammoth/translate/translator.py @@ -8,58 +8,58 @@ import torch -import onmt.model_builder -import onmt.decoders.ensemble -# from onmt.inputters.text_dataset import InferenceDataIterator -from onmt.translate.beam_search import BeamSearch, BeamSearchLM -from onmt.translate.greedy_search import GreedySearch, GreedySearchLM -from onmt.utils.misc import tile, set_random_seed, report_matrix -from onmt.utils.alignment import extract_alignment, build_align_pharaoh -from onmt.modules.copy_generator import collapse_copy_scores -from onmt.constants import ModelTask, DefaultTokens -from onmt.inputters.dataset import ParallelCorpus -from onmt.inputters.dataloader import build_dataloader - - -def build_translator(opt, task, report_score=True, logger=None, out_file=None): +import mammoth.model_builder +import mammoth.modules.decoder_ensemble +# from mammoth.inputters.text_dataset import InferenceDataIterator +from mammoth.translate.beam_search import BeamSearch, BeamSearchLM +from mammoth.translate.greedy_search import GreedySearch, GreedySearchLM +from mammoth.utils.misc import tile, set_random_seed, report_matrix +from mammoth.utils.alignment import extract_alignment, build_align_pharaoh +from mammoth.constants import ModelTask, DefaultTokens +from mammoth.inputters.dataset import ParallelCorpus +from mammoth.inputters.dataloader import build_dataloader + + +def build_translator(opts, task, report_score=True, logger=None, out_file=None): if out_file is None: - outdir = os.path.dirname(opt.output) + outdir = os.path.dirname(opts.output) if outdir and not os.path.isdir(outdir): # FIXME use warnings instead logger.info('WARNING: output file directory does not exist... creating it.') - os.makedirs(os.path.dirname(opt.output), exist_ok=True) - out_file = codecs.open(opt.output, "w+", "utf-8") + os.makedirs(os.path.dirname(opts.output), exist_ok=True) + out_file = codecs.open(opts.output, "w+", "utf-8") load_test_model = ( - onmt.decoders.ensemble.load_test_model if len(opt.models) > 3 else onmt.model_builder.load_test_multitask_model + mammoth.modules.decoder_ensemble.load_test_model if len(opts.models) > 3 + else mammoth.model_builder.load_test_multitask_model ) if logger: logger.info(str(task)) - vocabs, model, model_opt = load_test_model(opt) + vocabs, model, model_opts = load_test_model(opts) - scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt) + scorer = mammoth.translate.GNMTGlobalScorer.from_opts(opts) - if model_opt.model_task == ModelTask.LANGUAGE_MODEL: - translator = GeneratorLM.from_opt( + if model_opts.model_task == ModelTask.LANGUAGE_MODEL: + translator = GeneratorLM.from_opts( model, vocabs, - opt, - model_opt, + opts, + model_opts, global_scorer=scorer, out_file=out_file, - report_align=opt.report_align, + report_align=opts.report_align, report_score=report_score, logger=logger, ) else: - translator = Translator.from_opt( + translator = Translator.from_opts( model, vocabs, - opt, - model_opt, + opts, + model_opts, global_scorer=scorer, out_file=out_file, - report_align=opt.report_align, + report_align=opts.report_align, report_score=report_score, logger=logger, task=task, @@ -90,36 +90,36 @@ class Inference(object): """Translate a batch of sentences with a saved model. Args: - model (onmt.modules.NMTModel): NMT model to use for translation - vocabs (dict[str, onmt.inputters.Vocab]): A dict + model (mammoth.modules.NMTModel): NMT model to use for translation + vocabs (dict[str, mammoth.inputters.Vocab]): A dict mapping each side to its Vocab. src_file_path (str): Source file to read. tgt_reader (src): Target file, if necessary. gpu (int): GPU device. Set to negative for no GPU. n_best (int): How many beams to wait for. min_length (int): See - :class:`onmt.translate.decode_strategy.DecodeStrategy`. + :class:`mammoth.translate.decode_strategy.DecodeStrategy`. max_length (int): See - :class:`onmt.translate.decode_strategy.DecodeStrategy`. + :class:`mammoth.translate.decode_strategy.DecodeStrategy`. beam_size (int): Number of beams. random_sampling_topk (int): See - :class:`onmt.translate.greedy_search.GreedySearch`. + :class:`mammoth.translate.greedy_search.GreedySearch`. random_sampling_temp (float): See - :class:`onmt.translate.greedy_search.GreedySearch`. + :class:`mammoth.translate.greedy_search.GreedySearch`. stepwise_penalty (bool): Whether coverage penalty is applied every step or not. dump_beam (bool): Debugging option. block_ngram_repeat (int): See - :class:`onmt.translate.decode_strategy.DecodeStrategy`. + :class:`mammoth.translate.decode_strategy.DecodeStrategy`. ignore_when_blocking (set or frozenset): See - :class:`onmt.translate.decode_strategy.DecodeStrategy`. + :class:`mammoth.translate.decode_strategy.DecodeStrategy`. replace_unk (bool): Replace unknown token. tgt_prefix (bool): Force the predictions begin with provided -tgt. data_type (str): Source data type. verbose (bool): Print/log every translation. report_time (bool): Print/log total time/frequency. copy_attn (bool): Use copy attention. - global_scorer (onmt.translate.GNMTGlobalScorer): Translation + global_scorer (mammoth.translate.GNMTGlobalScorer): Translation scoring/reranking object. out_file (TextIO or codecs.StreamReaderWriter): Output file. report_score (bool) : Whether to report scores @@ -236,12 +236,12 @@ def __init__( set_random_seed(seed, self._use_cuda) @classmethod - def from_opt( + def from_opts( cls, model, vocabs, - opt, - model_opt, + opts, + model_opts, global_scorer=None, out_file=None, report_align=False, @@ -252,13 +252,13 @@ def from_opt( """Alternate constructor. Args: - model (onmt.modules.NMTModel): See :func:`__init__()`. - vocabs (dict[str, onmt.inputters.Vocab]): See + model (mammoth.modules.NMTModel): See :func:`__init__()`. + vocabs (dict[str, mammoth.inputters.Vocab]): See :func:`__init__()`. - opt (argparse.Namespace): Command line options - model_opt (argparse.Namespace): Command line options saved with + opts (argparse.Namespace): Command line options + model_opts (argparse.Namespace): Command line options saved with the model checkpoint. - global_scorer (onmt.translate.GNMTGlobalScorer): See + global_scorer (mammoth.translate.GNMTGlobalScorer): See :func:`__init__()`.. out_file (TextIO or codecs.StreamReaderWriter): See :func:`__init__()`. @@ -268,40 +268,40 @@ def from_opt( """ assert task is not None # TODO: maybe add dynamic part - cls.validate_task(model_opt.model_task) + cls.validate_task(model_opts.model_task) return cls( model, vocabs, - opt.src, - tgt_file_path=opt.tgt, - gpu=opt.gpu, - n_best=opt.n_best, - min_length=opt.min_length, - max_length=opt.max_length, - ratio=opt.ratio, - beam_size=opt.beam_size, - random_sampling_topk=opt.random_sampling_topk, - random_sampling_topp=opt.random_sampling_topp, - random_sampling_temp=opt.random_sampling_temp, - stepwise_penalty=opt.stepwise_penalty, - dump_beam=opt.dump_beam, - block_ngram_repeat=opt.block_ngram_repeat, - ignore_when_blocking=set(opt.ignore_when_blocking), - replace_unk=opt.replace_unk, - ban_unk_token=opt.ban_unk_token, - tgt_prefix=opt.tgt_prefix, - phrase_table=opt.phrase_table, - data_type=opt.data_type, - verbose=opt.verbose, - report_time=opt.report_time, - copy_attn=model_opt.copy_attn, + opts.src, + tgt_file_path=opts.tgt, + gpu=opts.gpu, + n_best=opts.n_best, + min_length=opts.min_length, + max_length=opts.max_length, + ratio=opts.ratio, + beam_size=opts.beam_size, + random_sampling_topk=opts.random_sampling_topk, + random_sampling_topp=opts.random_sampling_topp, + random_sampling_temp=opts.random_sampling_temp, + stepwise_penalty=opts.stepwise_penalty, + dump_beam=opts.dump_beam, + block_ngram_repeat=opts.block_ngram_repeat, + ignore_when_blocking=set(opts.ignore_when_blocking), + replace_unk=opts.replace_unk, + ban_unk_token=opts.ban_unk_token, + tgt_prefix=opts.tgt_prefix, + phrase_table=opts.phrase_table, + data_type=opts.data_type, + verbose=opts.verbose, + report_time=opts.report_time, + copy_attn=model_opts.copy_attn, global_scorer=global_scorer, out_file=out_file, report_align=report_align, report_score=report_score, logger=logger, - seed=opt.seed, + seed=opts.seed, task=task, ) @@ -499,7 +499,7 @@ def _translate( # ) # data_iter = None - xlation_builder = onmt.translate.TranslationBuilder( + xlation_builder = mammoth.translate.TranslationBuilder( corpus, self.vocabs, self.n_best, @@ -669,39 +669,14 @@ def _decode_and_generate( ) # Generator forward. - if not self.copy_attn: - if "std" in dec_attn: - attn = dec_attn["std"] - else: - attn = None - log_probs = self.model.generator[f"generator_{self.task.tgt_lang}"](dec_out.squeeze(0)) - # returns [(batch_size x beam_size) , vocab ] when 1 step - # or [ tgt_len, batch_size, vocab ] when full sentence + if "std" in dec_attn: + attn = dec_attn["std"] else: - attn = dec_attn["copy"] - scores = self.model.generator[f"generator_{self.task.tgt_lang}"]( - dec_out.view(-1, dec_out.size(2)), - attn.view(-1, attn.size(2)), - src_map, - ) - # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab] - if batch_offset is None: - scores = scores.view(-1, batch.batch_size, scores.size(-1)) - scores = scores.transpose(0, 1).contiguous() - else: - scores = scores.view(-1, self.beam_size, scores.size(-1)) - scores = collapse_copy_scores( - scores, - batch, - self._tgt_vocab, - src_vocabs, - batch_dim=0, - batch_offset=batch_offset, - ) - scores = scores.view(decoder_in.size(0), -1, scores.size(-1)) - log_probs = scores.squeeze(0).log() - # returns [(batch_size x beam_size) , vocab ] when 1 step - # or [ tgt_len, batch_size, vocab ] when full sentence + attn = None + log_probs = self.model.generator[f"generator_{self.task.tgt_lang}"](dec_out.squeeze(0)) + # returns [(batch_size x beam_size) , vocab ] when 1 step + # or [ tgt_len, batch_size, vocab ] when full sentence + return log_probs, attn def translate_batch(self, batch, src_vocabs, attn_debug): diff --git a/mammoth/utils/__init__.py b/mammoth/utils/__init__.py new file mode 100644 index 00000000..49933156 --- /dev/null +++ b/mammoth/utils/__init__.py @@ -0,0 +1,25 @@ +"""Module defining various utilities.""" +from mammoth.utils.misc import split_corpus, aeq, use_gpu, set_random_seed +from mammoth.utils.alignment import make_batch_align_matrix +from mammoth.utils.report_manager import ReportMgr, build_report_manager +from mammoth.utils.statistics import Statistics +from mammoth.utils.optimizers import MultipleOptimizer, Optimizer, AdaFactorFairSeq +from mammoth.utils.earlystopping import EarlyStopping, scorers_from_opts +from mammoth.utils.loss import build_loss_compute + +__all__ = [ + "split_corpus", + "aeq", + "use_gpu", + "set_random_seed", + "ReportMgr", + "build_report_manager", + "Statistics", + "MultipleOptimizer", + "Optimizer", + "AdaFactorFairSeq", + "EarlyStopping", + "scorers_from_opts", + "make_batch_align_matrix", + "build_loss_compute", +] diff --git a/onmt/utils/alignment.py b/mammoth/utils/alignment.py similarity index 99% rename from onmt/utils/alignment.py rename to mammoth/utils/alignment.py index f58761b9..b7a1e6b4 100644 --- a/onmt/utils/alignment.py +++ b/mammoth/utils/alignment.py @@ -2,7 +2,7 @@ import torch from itertools import accumulate -from onmt.constants import SubwordMarker +from mammoth.constants import SubwordMarker def make_batch_align_matrix(index_tensor, size=None, normalize=False): diff --git a/onmt/utils/earlystopping.py b/mammoth/utils/earlystopping.py similarity index 96% rename from onmt/utils/earlystopping.py rename to mammoth/utils/earlystopping.py index 4244cf72..6d20c60f 100644 --- a/onmt/utils/earlystopping.py +++ b/mammoth/utils/earlystopping.py @@ -1,5 +1,5 @@ from enum import Enum -from onmt.utils.logging import logger +from mammoth.utils.logging import logger class PatienceEnum(Enum): @@ -63,12 +63,12 @@ def _caller(self, stats): SCORER_BUILDER = {"ppl": PPLScorer, "accuracy": AccuracyScorer} -def scorers_from_opts(opt): - if opt.early_stopping_criteria is None: +def scorers_from_opts(opts): + if opts.early_stopping_criteria is None: return DEFAULT_SCORERS else: scorers = [] - for criterion in set(opt.early_stopping_criteria): + for criterion in set(opts.early_stopping_criteria): assert criterion in SCORER_BUILDER.keys(), "Criterion {} not found".format(criterion) scorers.append(SCORER_BUILDER[criterion]()) return scorers diff --git a/onmt/utils/logging.py b/mammoth/utils/logging.py similarity index 100% rename from onmt/utils/logging.py rename to mammoth/utils/logging.py diff --git a/onmt/utils/loss.py b/mammoth/utils/loss.py similarity index 86% rename from onmt/utils/loss.py rename to mammoth/utils/loss.py index 100e7eb6..5061325d 100644 --- a/onmt/utils/loss.py +++ b/mammoth/utils/loss.py @@ -6,13 +6,11 @@ import torch.nn as nn import torch.nn.functional as F -import onmt -from onmt.modules.sparse_losses import SparsemaxLoss -from onmt.modules.sparse_activations import LogSparsemax -from onmt.constants import ModelTask, DefaultTokens +import mammoth +from mammoth.constants import ModelTask, DefaultTokens -def build_loss_compute(model, tgt_vocab, opt, train=True, generator=None): +def build_loss_compute(model, tgt_vocab, opts, train=True, generator=None): """ Returns a LossCompute subclass which wraps around an nn.Module subclass (such as nn.NLLLoss) which defines the loss criterion. The LossCompute @@ -21,62 +19,60 @@ def build_loss_compute(model, tgt_vocab, opt, train=True, generator=None): Currently, the NMTLossCompute class handles all loss computation except for when using a copy mechanism. """ - device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") + device = torch.device("cuda" if mammoth.utils.misc.use_gpu(opts) else "cpu") padding_idx = tgt_vocab.stoi[DefaultTokens.PAD] unk_idx = tgt_vocab.stoi[DefaultTokens.UNK] - if opt.lambda_coverage != 0: - assert opt.coverage_attn, "--coverage_attn needs to be set in order to use --lambda_coverage != 0" + if opts.lambda_coverage != 0: + assert opts.coverage_attn, "--coverage_attn needs to be set in order to use --lambda_coverage != 0" - if opt.copy_attn: - criterion = onmt.modules.CopyGeneratorLoss( - len(tgt_vocab), opt.copy_attn_force, unk_index=unk_idx, ignore_index=padding_idx + if opts.copy_attn: + criterion = mammoth.modules.CopyGeneratorLoss( + len(tgt_vocab), opts.copy_attn_force, unk_index=unk_idx, ignore_index=padding_idx ) - elif opt.label_smoothing > 0 and train: - criterion = LabelSmoothingLoss(opt.label_smoothing, len(tgt_vocab), ignore_index=padding_idx) - elif isinstance(generator[-1], LogSparsemax): # elif isinstance(model.generator[-1], LogSparsemax): - criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum') + elif opts.label_smoothing > 0 and train: + criterion = LabelSmoothingLoss(opts.label_smoothing, len(tgt_vocab), ignore_index=padding_idx) else: criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum') # if the loss function operates on vectors of raw logits instead of # probabilities, only the first part of the generator needs to be - # passed to the NMTLossCompute. At the moment, the only supported - # loss function of this kind is the sparsemax loss. - use_raw_logits = isinstance(criterion, SparsemaxLoss) + # passed to the NMTLossCompute. At the moment, there is no supported + # loss function of this kind. + use_raw_logits = False loss_gen = ( generator[0] if use_raw_logits else generator ) # loss_gen = model.generator[0] if use_raw_logits else model.generator - if opt.copy_attn: - if opt.model_task == ModelTask.SEQ2SEQ: - compute = onmt.modules.CopyGeneratorLossCompute( - criterion, loss_gen, tgt_vocab, opt.copy_loss_by_seqlength, lambda_coverage=opt.lambda_coverage + if opts.copy_attn: + if opts.model_task == ModelTask.SEQ2SEQ: + compute = mammoth.modules.CopyGeneratorLossCompute( + criterion, loss_gen, tgt_vocab, opts.copy_loss_by_seqlength, lambda_coverage=opts.lambda_coverage ) - elif opt.model_task == ModelTask.LANGUAGE_MODEL: - compute = onmt.modules.CopyGeneratorLMLossCompute( - criterion, loss_gen, tgt_vocab, opt.copy_loss_by_seqlength, lambda_coverage=opt.lambda_coverage + elif opts.model_task == ModelTask.LANGUAGE_MODEL: + compute = mammoth.modules.CopyGeneratorLMLossCompute( + criterion, loss_gen, tgt_vocab, opts.copy_loss_by_seqlength, lambda_coverage=opts.lambda_coverage ) else: - raise ValueError(f"No copy generator loss defined for task {opt.model_task}") + raise ValueError(f"No copy generator loss defined for task {opts.model_task}") else: - if opt.model_task == ModelTask.SEQ2SEQ: + if opts.model_task == ModelTask.SEQ2SEQ: compute = NMTLossCompute( criterion, loss_gen, - lambda_coverage=opt.lambda_coverage, - lambda_align=opt.lambda_align, + lambda_coverage=opts.lambda_coverage, + lambda_align=opts.lambda_align, ) - elif opt.model_task == ModelTask.LANGUAGE_MODEL: - assert opt.lambda_align == 0.0, "lamdba_align not supported in LM loss" + elif opts.model_task == ModelTask.LANGUAGE_MODEL: + assert opts.lambda_align == 0.0, "lamdba_align not supported in LM loss" compute = LMLossCompute( criterion, loss_gen, - lambda_coverage=opt.lambda_coverage, - lambda_align=opt.lambda_align, + lambda_coverage=opts.lambda_coverage, + lambda_align=opts.lambda_align, ) else: - raise ValueError(f"No compute loss defined for task {opt.model_task}") + raise ValueError(f"No compute loss defined for task {opts.model_task}") compute.to(device) return compute @@ -163,7 +159,7 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_ trunc_size (int) : length of truncation window Returns: - A tuple with the loss and a :obj:`onmt.utils.Statistics` instance. + A tuple with the loss and a :obj:`mammoth.utils.Statistics` instance. """ if trunc_size is None: trunc_size = batch.tgt.size(0) - trunc_start @@ -172,7 +168,7 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_ if shard_size == 0: loss, stats = self._compute_loss(batch, **shard_state) return loss / float(normalization), stats - batch_stats = onmt.utils.Statistics() + batch_stats = mammoth.utils.Statistics() for shard in shards(shard_state, shard_size): loss, stats = self._compute_loss(batch, **shard) loss.div(float(normalization)).backward() # retain_graph=True) @@ -187,13 +183,13 @@ def _stats(self, loss, scores, labels): labels (:obj:`FloatTensor`): true targets Returns: - :obj:`onmt.utils.Statistics` : statistics for this batch. + :obj:`mammoth.utils.Statistics` : statistics for this batch. """ pred = scores.max(1)[1] non_padding = labels.ne(self.padding_idx) num_correct = pred.eq(labels).masked_select(non_padding).sum().item() num_non_padding = non_padding.sum().item() - return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct) + return mammoth.utils.Statistics(loss.item(), num_non_padding, num_correct) def _bottle(self, _v): return _v.view(-1, _v.size(2)) @@ -304,7 +300,7 @@ def _add_align_shard_state(self, shard_state, batch, range_start, range_end, att pad_tgt_size, batch_size, _ = batch.tgt.size() pad_src_size = batch.src[0].size(0) align_matrix_size = [batch_size, pad_tgt_size, pad_src_size] - ref_align = onmt.utils.make_batch_align_matrix(align_idx, align_matrix_size, normalize=True) + ref_align = mammoth.utils.make_batch_align_matrix(align_idx, align_matrix_size, normalize=True) # NOTE: tgt-src ref alignement that in range_ of shard # (coherent with batch.tgt) shard_state.update( diff --git a/onmt/utils/misc.py b/mammoth/utils/misc.py similarity index 97% rename from onmt/utils/misc.py rename to mammoth/utils/misc.py index 36cd5b82..280932f6 100644 --- a/onmt/utils/misc.py +++ b/mammoth/utils/misc.py @@ -79,11 +79,11 @@ def tile(x, count, dim=0): return x -def use_gpu(opt): +def use_gpu(opts): """ Creates a boolean if gpu used """ - return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or (hasattr(opt, 'gpu') and opt.gpu > -1) + return (hasattr(opts, 'gpu_ranks') and len(opts.gpu_ranks) > 0) or (hasattr(opts, 'gpu') and opts.gpu > -1) def set_random_seed(seed, is_cuda): diff --git a/onmt/utils/module_splitter.py b/mammoth/utils/module_splitter.py similarity index 98% rename from onmt/utils/module_splitter.py rename to mammoth/utils/module_splitter.py index 738037be..6e190a3a 100644 --- a/onmt/utils/module_splitter.py +++ b/mammoth/utils/module_splitter.py @@ -62,7 +62,7 @@ def explode_model(full_ab_model): # stuff necessary to build bilingual models combining modules model_frame = { "vocab": full_ab_model["vocab"], - "opt": full_ab_model["opt"], + "opts": full_ab_model["opts"], "optim": full_ab_model["optim"], } diff --git a/onmt/utils/optimizers.py b/mammoth/utils/optimizers.py similarity index 90% rename from onmt/utils/optimizers.py rename to mammoth/utils/optimizers.py index c6043683..57e20ebf 100644 --- a/onmt/utils/optimizers.py +++ b/mammoth/utils/optimizers.py @@ -7,7 +7,7 @@ import types from collections import Counter from math import sqrt -from onmt.utils.misc import fn_args +from mammoth.utils.misc import fn_args from torch.nn.utils import clip_grad_norm_ @@ -60,7 +60,7 @@ def attention_bridge_optimizer(model, task_queue_manager, base_optimizer): return optimizer -def build_torch_optimizer(model, opt, task_queue_manager): +def build_torch_optimizer(model, opts, task_queue_manager): """Builds the PyTorch optimizer. We use the default parameters for Adam that are suggested by @@ -76,116 +76,105 @@ def build_torch_optimizer(model, opt, task_queue_manager): Args: model: The model to optimize. - opt. The dictionary of options. + opts. The dictionary of options. Returns: A ``torch.optim.Optimizer`` instance. """ params = [p for p in model.parameters() if p.requires_grad] - betas = [opt.adam_beta1, opt.adam_beta2] - if opt.optim == 'sgd': - optimizer = optim.SGD(params, lr=opt.learning_rate) - elif opt.optim == 'adagrad': - optimizer = optim.Adagrad(params, lr=opt.learning_rate, initial_accumulator_value=opt.adagrad_accumulator_init) - elif opt.optim == 'adadelta': - optimizer = optim.Adadelta(params, lr=opt.learning_rate) - elif opt.optim == 'adafactor': + betas = [opts.adam_beta1, opts.adam_beta2] + if opts.optim == 'sgd': + optimizer = optim.SGD(params, lr=opts.learning_rate) + elif opts.optim == 'adagrad': + optimizer = optim.Adagrad( + params, + lr=opts.learning_rate, + initial_accumulator_value=opts.adagrad_accumulator_init, + ) + elif opts.optim == 'adadelta': + optimizer = optim.Adadelta(params, lr=opts.learning_rate) + elif opts.optim == 'adafactor': optimizer = attention_bridge_optimizer( model, task_queue_manager, - lambda params: AdaFactorFairSeq(params, weight_decay=opt.weight_decay), + lambda params: AdaFactorFairSeq(params, weight_decay=opts.weight_decay), ) - elif opt.optim == 'adam': + elif opts.optim == 'adam': optimizer = attention_bridge_optimizer( model, task_queue_manager, lambda params: optim.Adam( - params, lr=opt.learning_rate, betas=betas, eps=1e-9, weight_decay=opt.weight_decay + params, lr=opts.learning_rate, betas=betas, eps=1e-9, weight_decay=opts.weight_decay ) ) - elif opt.optim == 'adamw': + elif opts.optim == 'adamw': optimizer = attention_bridge_optimizer( model, task_queue_manager, lambda params: optim.AdamW( - params, lr=opt.learning_rate, betas=betas, eps=1e-9, weight_decay=opt.weight_decay + params, lr=opts.learning_rate, betas=betas, eps=1e-9, weight_decay=opts.weight_decay ) ) - elif opt.optim == 'sparseadam': - encs = [] - decs = [] - for name, param in model.named_parameters(): - if not param.requires_grad: - continue - # TODO: Find a better way to check for sparse gradients. - if 'decoder' in name: - # print(name) - decs.append(param) - else: - encs.append(param) - optimizer = MultipleOptimizer( - [optim.Adam(encs, lr=opt.learning_rate, betas=betas, eps=1e-9), AdaFactorFairSeq(decs, warmup_init=True)] - ) - elif opt.optim == 'fusedadam': + elif opts.optim == 'fusedadam': # we use here a FusedAdam() copy of an old Apex repo - optimizer = FusedAdam(params, lr=opt.learning_rate, betas=betas) - if opt.model_dtype == 'fp16': + optimizer = FusedAdam(params, lr=opts.learning_rate, betas=betas) + if opts.model_dtype == 'fp16': import apex # In this case use the old FusedAdam with FP16_optimizer wrapper - static_loss_scale = opt.loss_scale - dynamic_loss_scale = opt.loss_scale == 0 + static_loss_scale = opts.loss_scale + dynamic_loss_scale = opts.loss_scale == 0 optimizer = apex.contrib.optimizers.FP16_Optimizer( optimizer, static_loss_scale=static_loss_scale, dynamic_loss_scale=dynamic_loss_scale ) else: - raise ValueError('Invalid optimizer type: ' + opt.optim) + raise ValueError('Invalid optimizer type: ' + opts.optim) return optimizer -def make_learning_rate_decay_fn(opt): +def make_learning_rate_decay_fn(opts): """Returns the learning decay function from options.""" - if opt.decay_method == 'noam': - return functools.partial(noam_decay, warmup_steps=opt.warmup_steps, model_size=opt.rnn_size) - elif opt.decay_method == 'noamwd': + if opts.decay_method == 'noam': + return functools.partial(noam_decay, warmup_steps=opts.warmup_steps, model_dim=opts.model_dim) + elif opts.decay_method == 'noamwd': return functools.partial( noamwd_decay, - warmup_steps=opt.warmup_steps, - model_size=opt.rnn_size, - rate=opt.learning_rate_decay, - decay_steps=opt.decay_steps, - start_step=opt.start_decay_steps, + warmup_steps=opts.warmup_steps, + model_dim=opts.model_dim, + rate=opts.learning_rate_decay, + decay_steps=opts.decay_steps, + start_step=opts.start_decay_steps, ) - elif opt.decay_method == 'rsqrt': - return functools.partial(rsqrt_decay, warmup_steps=opt.warmup_steps) - elif opt.decay_method == 'linear_warmup': + elif opts.decay_method == 'rsqrt': + return functools.partial(rsqrt_decay, warmup_steps=opts.warmup_steps) + elif opts.decay_method == 'linear_warmup': return functools.partial( linear_warmup_decay, - warmup_steps=opt.warmup_steps, - rate=opt.learning_rate, - train_steps=opt.train_steps, + warmup_steps=opts.warmup_steps, + rate=opts.learning_rate, + train_steps=opts.train_steps, ) - elif opt.start_decay_steps is not None: + elif opts.start_decay_steps is not None: return functools.partial( exponential_decay, - rate=opt.learning_rate_decay, - decay_steps=opt.decay_steps, - start_step=opt.start_decay_steps, + rate=opts.learning_rate_decay, + decay_steps=opts.decay_steps, + start_step=opts.start_decay_steps, ) -def noam_decay(step, warmup_steps, model_size): +def noam_decay(step, warmup_steps, model_dim): """Learning rate schedule described in https://arxiv.org/pdf/1706.03762.pdf. """ - return model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)) + return model_dim ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)) -def noamwd_decay(step, warmup_steps, model_size, rate, decay_steps, start_step=0): +def noamwd_decay(step, warmup_steps, model_dim, rate, decay_steps, start_step=0): """Learning rate schedule optimized for huge batches""" return ( - model_size ** (-0.5) + model_dim ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)) * rate ** (max(step - start_step + decay_steps, 0) // decay_steps) ) @@ -212,7 +201,7 @@ def linear_warmup_decay(step, warmup_steps, rate, train_steps): class MultipleOptimizer(object): - """Implement multiple optimizers needed for sparse adam""" + """Implement multiple optimizers""" def __init__(self, op, multiOptims_Langs=None): self.optimizers = op @@ -290,24 +279,24 @@ def __init__(self, optimizer, learning_rate, learning_rate_decay_fn=None, max_gr self._scaler = None @classmethod - def from_opt(cls, model, opt, task_queue_manager, checkpoint=None): + def from_opts(cls, model, opts, task_queue_manager, checkpoint=None): """Builds the optimizer from options. Args: cls: The ``Optimizer`` class to instantiate. model: The model to optimize. - opt: The dict of user options. + opts: The dict of user options. checkpoint: An optional checkpoint to load states from. Returns: An ``Optimizer`` instance. """ - optim_opt = opt + optim_opt = opts optim_state_dict = None - if opt.train_from and checkpoint is not None: + if opts.train_from and checkpoint is not None: optim = checkpoint['optim'] - ckpt_opt = checkpoint['opt'] + ckpt_opt = checkpoint['opts'] ckpt_state_dict = {} if isinstance(optim, Optimizer): # Backward compatibility. ckpt_state_dict['training_step'] = optim._step + 1 @@ -316,19 +305,19 @@ def from_opt(cls, model, opt, task_queue_manager, checkpoint=None): else: ckpt_state_dict = optim - if opt.reset_optim == 'none': + if opts.reset_optim == 'none': # Load everything from the checkpoint. optim_opt = ckpt_opt optim_state_dict = ckpt_state_dict - elif opt.reset_optim == 'all': + elif opts.reset_optim == 'all': # Build everything from scratch. pass - elif opt.reset_optim == 'states': + elif opts.reset_optim == 'states': # Reset optimizer, keep options. optim_opt = ckpt_opt optim_state_dict = ckpt_state_dict del optim_state_dict['optimizer'] - elif opt.reset_optim == 'keep_states': + elif opts.reset_optim == 'keep_states': # Reset options, keep optimizer. optim_state_dict = ckpt_state_dict @@ -339,8 +328,8 @@ def from_opt(cls, model, opt, task_queue_manager, checkpoint=None): max_grad_norm=optim_opt.max_grad_norm, ) - if opt.model_dtype == "fp16": - if opt.optim == "fusedadam": + if opts.model_dtype == "fp16": + if opts.optim == "fusedadam": optimizer._fp16 = "legacy" else: optimizer._fp16 = "amp" @@ -734,11 +723,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=1.0, grad_nor if grad is None: grad = p.grad.data if grad.is_sparse: - raise RuntimeError( - 'FusedAdam does not support sparse \ - gradients, please consider \ - SparseAdam instead' - ) + raise RuntimeError('sparse gradient not supported') state = self.state[p] diff --git a/onmt/utils/parse.py b/mammoth/utils/parse.py similarity index 61% rename from onmt/utils/parse.py rename to mammoth/utils/parse.py index e979e338..54056cf9 100644 --- a/onmt/utils/parse.py +++ b/mammoth/utils/parse.py @@ -4,10 +4,10 @@ import torch import yaml -import onmt.opts as opts -from onmt.utils.logging import logger -from onmt.constants import CorpusName, ModelTask -from onmt.transforms import AVAILABLE_TRANSFORMS +import mammoth.opts as opts +from mammoth.utils.logging import logger +from mammoth.constants import CorpusName, ModelTask +from mammoth.transforms import AVAILABLE_TRANSFORMS RE_NODE_GPU = re.compile(r'\d+:\d+') RE_SRC_TGT = re.compile(r'[^-]+-[^-]+') @@ -23,21 +23,21 @@ def _validate_file(file_path, info): raise IOError(f"Please check path of your {info} file! {file_path}") @classmethod - def _validate_adapters(cls, opt): + def _validate_adapters(cls, opts): """Parse corpora specified in data field of YAML file.""" - if not opt.adapters: + if not opts.adapters: return - adapter_opts = yaml.safe_load(opt.adapters) + adapter_opts = yaml.safe_load(opts.adapters) # TODO: validate adapter opts - opt.adapters = adapter_opts + opts.adapters = adapter_opts @classmethod - def _validate_data(cls, opt): + def _validate_data(cls, opts): """Parse tasks/language-pairs/corpora specified in data field of YAML file.""" - default_transforms = opt.transforms + default_transforms = opts.transforms if len(default_transforms) != 0: logger.info(f"Default transforms: {default_transforms}.") - corpora = yaml.safe_load(opt.data) + corpora = yaml.safe_load(opts.tasks) logger.info("Parsing corpora") n_without_node_gpu = 0 for cname, corpus in corpora.items(): @@ -47,7 +47,7 @@ def _validate_data(cls, opt): if _transforms is None: logger.info(f"Missing transforms field for {cname} data, set to default: {default_transforms}.") corpus['transforms'] = default_transforms - opt.data_task = ModelTask.SEQ2SEQ + opts.data_task = ModelTask.SEQ2SEQ """ # Check path path_src = corpus.get('path_src', None) @@ -57,13 +57,13 @@ def _validate_data(cls, opt): 'tgt path is also required for non language' ' modeling tasks.') else: - opt.data_task = ModelTask.SEQ2SEQ + opts.data_task = ModelTask.SEQ2SEQ if path_tgt is None: logger.warning( "path_tgt is None, it should be set unless the task" " is language modeling" ) - opt.data_task = ModelTask.LANGUAGE_MODEL + opts.data_task = ModelTask.LANGUAGE_MODEL # tgt is src for LM task corpus["path_tgt"] = path_src corpora[cname] = corpus @@ -73,7 +73,7 @@ def _validate_data(cls, opt): """ path_align = corpus.get('path_align', None) if path_align is None: - if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0: + if hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0: raise ValueError(f'Corpus {cname} alignment file path are required when lambda_align > 0.0') corpus['path_align'] = None else: @@ -140,111 +140,111 @@ def _validate_data(cls, opt): assert n_without_node_gpu == 0 or n_without_node_gpu == len(corpora) logger.info(f"Parsed {len(corpora)} corpora from -data.") - opt.data = corpora + opts.tasks = corpora - src_vocab = yaml.safe_load(opt.src_vocab) + src_vocab = yaml.safe_load(opts.src_vocab) logger.info(f"Parsed {len(src_vocab)} vocabs from -src_vocab.") - opt.src_vocab = src_vocab + opts.src_vocab = src_vocab - tgt_vocab = yaml.safe_load(opt.tgt_vocab) + tgt_vocab = yaml.safe_load(opts.tgt_vocab) logger.info(f"Parsed {len(tgt_vocab)} vocabs from -tgt_vocab.") - opt.tgt_vocab = tgt_vocab + opts.tgt_vocab = tgt_vocab @classmethod - def _validate_transforms_opts(cls, opt): + def _validate_transforms_opts(cls, opts): """Check options used by transforms.""" for name, transform_cls in AVAILABLE_TRANSFORMS.items(): - if name in opt._all_transform: - transform_cls._validate_options(opt) + if name in opts._all_transform: + transform_cls._validate_options(opts) @classmethod - def _get_all_transform(cls, opt): + def _get_all_transform(cls, opts): """Should only called after `_validate_data`.""" - all_transforms = set(opt.transforms) - for cname, corpus in opt.data.items(): + all_transforms = set(opts.transforms) + for cname, corpus in opts.tasks.items(): _transforms = set(corpus['transforms']) if len(_transforms) != 0: all_transforms.update(_transforms) - if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0: + if hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0: if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}): raise ValueError('lambda_align is not compatible with on-the-fly tokenization.') if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'denoising'}): raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.') - opt._all_transform = all_transforms + opts._all_transform = all_transforms @classmethod - def _get_all_transform_translate(cls, opt): - opt._all_transform = opt.transforms + def _get_all_transform_translate(cls, opts): + opts._all_transform = opts.transforms @classmethod - def _validate_fields_opts(cls, opt, build_vocab_only=False): + def _validate_fields_opts(cls, opts, build_vocab_only=False): """Check options relate to vocab and fields.""" - for cname, corpus in opt.data.items(): + for cname, corpus in opts.tasks.items(): if cname != CorpusName.VALID and corpus["src_feats"] is not None: - assert opt.src_feats_vocab, "-src_feats_vocab is required if using source features." - if isinstance(opt.src_feats_vocab, str): - opt.src_feats_vocab = yaml.safe_load(opt.src_feats_vocab) + assert opts.src_feats_vocab, "-src_feats_vocab is required if using source features." + if isinstance(opts.src_feats_vocab, str): + opts.src_feats_vocab = yaml.safe_load(opts.src_feats_vocab) for feature in corpus["src_feats"].keys(): - assert feature in opt.src_feats_vocab, f"No vocab file set for feature {feature}" + assert feature in opts.src_feats_vocab, f"No vocab file set for feature {feature}" if build_vocab_only: - if not opt.share_vocab: - assert opt.tgt_vocab, "-tgt_vocab is required if not -share_vocab." + if not opts.share_vocab: + assert opts.tgt_vocab, "-tgt_vocab is required if not -share_vocab." return # validation when train: - for key, vocab in opt.src_vocab.items(): + for key, vocab in opts.src_vocab.items(): cls._validate_file(vocab, info=f'src vocab ({key})') - if not opt.share_vocab: - for key, vocab in opt.tgt_vocab.items(): + if not opts.share_vocab: + for key, vocab in opts.tgt_vocab.items(): cls._validate_file(vocab, info=f'tgt vocab ({key})') - # if opt.dump_fields or opt.dump_transforms: - if opt.dump_transforms: + # if opts.dump_fields or opts.dump_transforms: + if opts.dump_transforms: assert ( - opt.save_data + opts.save_data ), "-save_data should be set if set -dump_transforms." # Check embeddings stuff - if opt.both_embeddings is not None: + if opts.both_embeddings is not None: assert ( - opt.src_embeddings is None and opt.tgt_embeddings is None + opts.src_embeddings is None and opts.tgt_embeddings is None ), "You don't need -src_embeddings or -tgt_embeddings \ if -both_embeddings is set." - if any([opt.both_embeddings is not None, opt.src_embeddings is not None, opt.tgt_embeddings is not None]): - assert opt.embeddings_type is not None, "You need to specify an -embedding_type!" + if any([opts.both_embeddings is not None, opts.src_embeddings is not None, opts.tgt_embeddings is not None]): + assert opts.embeddings_type is not None, "You need to specify an -embedding_type!" assert ( - opt.save_data + opts.save_data ), "-save_data should be set if use pretrained embeddings." @classmethod - def _validate_language_model_compatibilities_opts(cls, opt): - if opt.model_task != ModelTask.LANGUAGE_MODEL: + def _validate_language_model_compatibilities_opts(cls, opts): + if opts.model_task != ModelTask.LANGUAGE_MODEL: return logger.info("encoder is not used for LM task") - assert opt.share_vocab and (opt.tgt_vocab is None), "vocab must be shared for LM task" + assert opts.share_vocab and (opts.tgt_vocab is None), "vocab must be shared for LM task" - assert opt.decoder_type == "transformer", "Only transformer decoder is supported for LM task" + assert opts.decoder_type == "transformer", "Only transformer decoder is supported for LM task" @classmethod - def validate_prepare_opts(cls, opt, build_vocab_only=False): + def validate_prepare_opts(cls, opts, build_vocab_only=False): """Validate all options relate to prepare (data/transform/vocab).""" - if opt.n_sample != 0: + if opts.n_sample != 0: assert ( - opt.save_data + opts.save_data ), "-save_data should be set if \ want save samples." - cls._validate_data(opt) - cls._get_all_transform(opt) - cls._validate_transforms_opts(opt) - cls._validate_fields_opts(opt, build_vocab_only=build_vocab_only) + cls._validate_data(opts) + cls._get_all_transform(opts) + cls._validate_transforms_opts(opts) + cls._validate_fields_opts(opts, build_vocab_only=build_vocab_only) @classmethod - def validate_model_opts(cls, opt): - cls._validate_language_model_compatibilities_opts(opt) + def validate_model_opts(cls, opts): + cls._validate_language_model_compatibilities_opts(opts) class ArgumentParser(cfargparse.ArgumentParser, DataOptsCheckerMixin): @@ -270,108 +270,103 @@ def defaults(cls, *args): return defaults @classmethod - def update_model_opts(cls, model_opt): - cls._validate_adapters(model_opt) - if model_opt.word_vec_size > 0: - model_opt.src_word_vec_size = model_opt.word_vec_size - model_opt.tgt_word_vec_size = model_opt.word_vec_size + def update_model_opts(cls, model_opts): + cls._validate_adapters(model_opts) + if model_opts.model_dim > 0: + model_opts.model_dim = model_opts.model_dim + model_opts.model_dim = model_opts.model_dim # Backward compatibility with "fix_word_vecs_*" opts - if hasattr(model_opt, 'fix_word_vecs_enc'): - model_opt.freeze_word_vecs_enc = model_opt.fix_word_vecs_enc - if hasattr(model_opt, 'fix_word_vecs_dec'): - model_opt.freeze_word_vecs_dec = model_opt.fix_word_vecs_dec + if hasattr(model_opts, 'fix_word_vecs_enc'): + model_opts.freeze_word_vecs_enc = model_opts.fix_word_vecs_enc + if hasattr(model_opts, 'fix_word_vecs_dec'): + model_opts.freeze_word_vecs_dec = model_opts.fix_word_vecs_dec - if model_opt.layers > 0: + if model_opts.layers > 0: raise Exception('--layers is deprecated') - if model_opt.rnn_size > 0: - model_opt.enc_rnn_size = model_opt.rnn_size - model_opt.dec_rnn_size = model_opt.rnn_size + model_opts.brnn = model_opts.encoder_type == "brnn" - model_opt.brnn = model_opt.encoder_type == "brnn" + if model_opts.copy_attn_type is None: + model_opts.copy_attn_type = model_opts.global_attention - if model_opt.copy_attn_type is None: - model_opt.copy_attn_type = model_opt.global_attention - - if model_opt.alignment_layer is None: - model_opt.alignment_layer = -2 - model_opt.lambda_align = 0.0 - model_opt.full_context_alignment = False + if model_opts.alignment_layer is None: + model_opts.alignment_layer = -2 + model_opts.lambda_align = 0.0 + model_opts.full_context_alignment = False @classmethod - def validate_model_opts(cls, model_opt): - assert model_opt.model_type in ["text"], "Unsupported model type %s" % model_opt.model_type + def validate_model_opts(cls, model_opts): + assert model_opts.model_type in ["text"], "Unsupported model type %s" % model_opts.model_type # encoder and decoder should be same sizes - same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size - assert same_size, "The encoder and decoder rnns must be the same size for now" + # assert same_size, "The encoder and decoder rnns must be the same size for now" - assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, "Using SRU requires -gpu_ranks set." - if model_opt.share_embeddings: - if model_opt.model_type != "text": + if model_opts.share_embeddings: + if model_opts.model_type != "text": raise AssertionError("--share_embeddings requires --model_type text.") - if model_opt.lambda_align > 0.0: - assert model_opt.decoder_type == 'transformer', "Only transformer is supported to joint learn alignment." + if model_opts.lambda_align > 0.0: + assert model_opts.decoder_type == 'transformer', "Only transformer is supported to joint learn alignment." assert ( - model_opt.alignment_layer < model_opt.dec_layers and model_opt.alignment_layer >= -model_opt.dec_layers + model_opts.alignment_layer < model_opts.dec_layers + and model_opts.alignment_layer >= -model_opts.dec_layers ), "N° alignment_layer should be smaller than number of layers." logger.info( "Joint learn alignment at layer [{}] " "with {} heads in full_context '{}'.".format( - model_opt.alignment_layer, model_opt.alignment_heads, model_opt.full_context_alignment + model_opts.alignment_layer, model_opts.alignment_heads, model_opts.full_context_alignment ) ) @classmethod def ckpt_model_opts(cls, ckpt_opt): - # Load default opt values, then overwrite with the opts in + # Load default opts values, then overwrite with the opts in # the checkpoint. That way, if there are new options added, # the defaults are used. - opt = cls.defaults(opts.model_opts) - opt.__dict__.update(ckpt_opt.__dict__) - return opt + the_opts = cls.defaults(opts.model_opts) + the_opts.__dict__.update(ckpt_opt.__dict__) + return the_opts @classmethod - def validate_train_opts(cls, opt): - if opt.epochs: + def validate_train_opts(cls, opts): + if opts.epochs: raise AssertionError("-epochs is deprecated please use -train_steps.") - if opt.truncated_decoder > 0 and max(opt.accum_count) > 1: + if opts.truncated_decoder > 0 and max(opts.accum_count) > 1: raise AssertionError("BPTT is not compatible with -accum > 1") - if opt.gpuid: + if opts.gpuid: raise AssertionError("gpuid is deprecated see world_size and gpu_ranks") - if torch.cuda.is_available() and not opt.gpu_ranks: + if torch.cuda.is_available() and not opts.gpu_ranks: logger.warn("You have a CUDA device, should run with -gpu_ranks") - if opt.world_size < len(opt.gpu_ranks): + if opts.world_size < len(opts.gpu_ranks): raise AssertionError("parameter counts of -gpu_ranks must be less or equal than -world_size.") - if len(opt.gpu_ranks) > 0 and opt.world_size == len(opt.gpu_ranks) and min(opt.gpu_ranks) > 0: + if len(opts.gpu_ranks) > 0 and opts.world_size == len(opts.gpu_ranks) and min(opts.gpu_ranks) > 0: raise AssertionError( "-gpu_ranks should have master(=0) rank unless -world_size is greater than len(gpu_ranks)." ) - assert len(opt.dropout) == len(opt.dropout_steps), "Number of dropout values must match accum_steps values" + assert len(opts.dropout) == len(opts.dropout_steps), "Number of dropout values must match accum_steps values" - assert len(opt.attention_dropout) == len( - opt.dropout_steps + assert len(opts.attention_dropout) == len( + opts.dropout_steps ), "Number of attention_dropout values must match accum_steps values" - assert len(opt.accum_count) == len( - opt.accum_steps + assert len(opts.accum_count) == len( + opts.accum_steps ), 'Number of accum_count values must match number of accum_steps' - if opt.update_vocab: - assert opt.train_from, "-update_vocab needs -train_from option" - assert opt.reset_optim in ['states', 'all'], '-update_vocab needs -reset_optim "states" or "all"' + if opts.update_vocab: + assert opts.train_from, "-update_vocab needs -train_from option" + assert opts.reset_optim in ['states', 'all'], '-update_vocab needs -reset_optim "states" or "all"' @classmethod - def validate_translate_opts(cls, opt): - opt.src_feats = eval(opt.src_feats) if opt.src_feats else {} + def validate_translate_opts(cls, opts): + opts.src_feats = eval(opts.src_feats) if opts.src_feats else {} @classmethod - def validate_translate_opts_dynamic(cls, opt): + def validate_translate_opts_dynamic(cls, opts): # It comes from training - # TODO: needs to be added as inference opt - opt.share_vocab = False + # TODO: needs to be added as inference opts + opts.share_vocab = False - opt.stack = yaml.safe_load(opt.stack) + opts.stack = yaml.safe_load(opts.stack) diff --git a/onmt/utils/report_manager.py b/mammoth/utils/report_manager.py similarity index 86% rename from onmt/utils/report_manager.py rename to mammoth/utils/report_manager.py index 35554a8a..822938d0 100644 --- a/onmt/utils/report_manager.py +++ b/mammoth/utils/report_manager.py @@ -2,28 +2,28 @@ import time from datetime import datetime -import onmt +import mammoth -from onmt.utils.logging import logger +from mammoth.utils.logging import logger -def build_report_manager(opt, node_rank, local_rank): - # Vanilla onmt has here an additional gpu_rank <= 0 +def build_report_manager(opts, node_rank, local_rank): + # Vanilla mammoth has here an additional gpu_rank <= 0 # which would cause only the first GPU of each node to log. # This change allows all GPUs to log. # Because tensorboard does not allow multiple processes writing into the same directory, # each device is treated as a separate run. - if opt.tensorboard: + if opts.tensorboard: from torch.utils.tensorboard import SummaryWriter - if not hasattr(opt, 'tensorboard_log_dir_dated'): - opt.tensorboard_log_dir_dated = opt.tensorboard_log_dir + datetime.now().strftime("/%b-%d_%H-%M-%S") + if not hasattr(opts, 'tensorboard_log_dir_dated'): + opts.tensorboard_log_dir_dated = opts.tensorboard_log_dir + datetime.now().strftime("/%b-%d_%H-%M-%S") - writer = SummaryWriter(f'{opt.tensorboard_log_dir_dated}-rank{node_rank}:{local_rank}', comment="Unmt") + writer = SummaryWriter(f'{opts.tensorboard_log_dir_dated}-rank{node_rank}:{local_rank}', comment="Unmt") else: writer = None - report_mgr = ReportMgr(opt.report_every, start_time=-1, tensorboard_writer=writer) + report_mgr = ReportMgr(opts.report_every, start_time=-1, tensorboard_writer=writer) return report_mgr @@ -73,9 +73,9 @@ def report_training(self, step, num_steps, learning_rate, patience, report_stats if step % self.report_every == 0: # if multigpu: # report_stats = \ - # onmt.utils.Statistics.all_gather_stats(report_stats) + # mammoth.utils.Statistics.all_gather_stats(report_stats) self._report_training(step, num_steps, learning_rate, patience, report_stats) - return onmt.utils.Statistics() + return mammoth.utils.Statistics() else: return report_stats @@ -125,7 +125,7 @@ def _report_training(self, step, num_steps, learning_rate, patience, report_stat report_stats.output(step, num_steps, learning_rate, self.start_time) self.maybe_log_tensorboard(report_stats, "progress", learning_rate, patience, step) - report_stats = onmt.utils.Statistics() + report_stats = mammoth.utils.Statistics() return report_stats diff --git a/onmt/utils/statistics.py b/mammoth/utils/statistics.py similarity index 98% rename from onmt/utils/statistics.py rename to mammoth/utils/statistics.py index 17f3e151..ff8292a5 100644 --- a/onmt/utils/statistics.py +++ b/mammoth/utils/statistics.py @@ -7,7 +7,7 @@ from collections import Counter from torch.linalg import norm -from onmt.utils.logging import logger +from mammoth.utils.logging import logger class Statistics(object): @@ -65,7 +65,7 @@ def all_gather_stats_list(stat_list, max_size=4096): our_stats(list([`Statistics`])): list of updated stats """ from torch.distributed import get_rank - from onmt.utils.distributed import all_gather_list + from mammoth.distributed import all_gather_list # Get a list of world_size lists with len(stat_list) Statistics objects all_stats = all_gather_list(stat_list, max_size=max_size) diff --git a/onmt/__init__.py b/onmt/__init__.py deleted file mode 100644 index 78a71d74..00000000 --- a/onmt/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -""" Main entry point of the ONMT library """ -import onmt.inputters -import onmt.encoders -import onmt.decoders -import onmt.models -import onmt.utils -import onmt.modules -import onmt.opts -from onmt.trainer import Trainer -import sys -import onmt.utils.optimizers - -onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer -sys.modules["onmt.Optim"] = onmt.utils.optimizers - -__all__ = [ - onmt.inputters, - onmt.encoders, - onmt.decoders, - onmt.models, - onmt.utils, - onmt.modules, - onmt.opts, - "Trainer" -] - -__version__ = "2.2.0" diff --git a/onmt/bin/translate.py b/onmt/bin/translate.py deleted file mode 100644 index 99ab0a82..00000000 --- a/onmt/bin/translate.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -from onmt.utils.logging import init_logger -from onmt.utils.misc import split_corpus -from onmt.translate.translator import build_translator -# from onmt.inputters.text_dataset import InferenceDataReader -from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe - -import onmt.opts as opts -from onmt.utils.distributed import TaskSpecs -from onmt.utils.parse import ArgumentParser - - -def translate(opt): - ArgumentParser.validate_translate_opts(opt) - ArgumentParser._get_all_transform_translate(opt) - ArgumentParser._validate_transforms_opts(opt) - ArgumentParser.validate_translate_opts_dynamic(opt) - logger = init_logger(opt.log_file) - - encoder_adapter_ids = set() - for layer_stack_idx, stack in enumerate(opt.stack['encoder']): - if 'adapters' in stack: - for group_id, sub_id in stack['adapters']: - encoder_adapter_ids.add((layer_stack_idx, group_id, sub_id)) - decoder_adapter_ids = set() - for layer_stack_idx, stack in enumerate(opt.stack['decoder']): - if 'adapters' in stack: - for group_id, sub_id in stack['adapters']: - decoder_adapter_ids.add((layer_stack_idx, group_id, sub_id)) - - logger.info( - 'It is ok that src_vocab and tgt_vocab are None here. ' - 'The vocabs are separately loaded in model_builder.' - ) - task = TaskSpecs( - node_rank=None, - local_rank=None, - src_lang=opt.src_lang, - tgt_lang=opt.tgt_lang, - encoder_id=[stack['id'] for stack in opt.stack['encoder']], - decoder_id=[stack['id'] for stack in opt.stack['decoder']], - corpus_id='trans', - weight=1, - corpus_opt=dict(), - src_vocab=None, - tgt_vocab=None, - encoder_adapter_ids=encoder_adapter_ids, - decoder_adapter_ids=decoder_adapter_ids, - ) - - translator = build_translator(opt, task, logger=logger, report_score=True) - - # data_reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats) - src_shards = split_corpus(opt.src, opt.shard_size) - tgt_shards = split_corpus(opt.tgt, opt.shard_size) - features_shards = [] - features_names = [] - for feat_name, feat_path in opt.src_feats.items(): - features_shards.append(split_corpus(feat_path, opt.shard_size)) - features_names.append(feat_name) - shard_pairs = zip(src_shards, tgt_shards, *features_shards) - - # Build transforms - transforms_cls = get_transforms_cls(opt._all_transform) - transforms = make_transforms(opt, transforms_cls, translator.vocabs, task=task) - data_transform = [ - transforms[name] for name in opt.transforms if name in transforms - ] - transform = TransformPipe.build_from(data_transform) - - for i, (src_shard, tgt_shard, *feats_shard) in enumerate(shard_pairs): - logger.info("Translating shard %d." % i) - translator.translate_dynamic( - src=src_shard, - transform=transform, - # src_feats=feats_shard, # TODO: put me back in - tgt=tgt_shard, - batch_size=opt.batch_size, - batch_type=opt.batch_type, - attn_debug=opt.attn_debug, - align_debug=opt.align_debug - ) - - -def _get_parser(): - parser = ArgumentParser(description='translate.py') - - opts.config_opts(parser) - opts.translate_opts(parser, dynamic=True) - opts.build_bilingual_model(parser) - return parser - - -def main(): - parser = _get_parser() - - opt = parser.parse_args() - translate(opt) - - -if __name__ == "__main__": - main() diff --git a/onmt/decoders/__init__.py b/onmt/decoders/__init__.py deleted file mode 100644 index ab6262e3..00000000 --- a/onmt/decoders/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Module defining decoders.""" -from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder -from onmt.decoders.transformer import TransformerDecoder -from onmt.decoders.cnn_decoder import CNNDecoder - - -str2dec = { - "rnn": StdRNNDecoder, - "ifrnn": InputFeedRNNDecoder, - "cnn": CNNDecoder, - "transformer": TransformerDecoder, -} - -__all__ = [ - "DecoderBase", - "TransformerDecoder", - "StdRNNDecoder", - "CNNDecoder", - "InputFeedRNNDecoder", - "str2dec", -] diff --git a/onmt/decoders/cnn_decoder.py b/onmt/decoders/cnn_decoder.py deleted file mode 100644 index 5a82f261..00000000 --- a/onmt/decoders/cnn_decoder.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Implementation of the CNN Decoder part of -"Convolutional Sequence to Sequence Learning" -""" -import torch -import torch.nn as nn - -from onmt.modules import ConvMultiStepAttention, GlobalAttention -from onmt.utils.cnn_factory import shape_transform, GatedConv -from onmt.decoders.decoder import DecoderBase - -SCALE_WEIGHT = 0.5**0.5 - - -class CNNDecoder(DecoderBase): - """Decoder based on "Convolutional Sequence to Sequence Learning" - :cite:`DBLP:journals/corr/GehringAGYD17`. - - Consists of residual convolutional layers, with ConvMultiStepAttention. - """ - - def __init__( - self, num_layers, hidden_size, attn_type, copy_attn, cnn_kernel_width, dropout, embeddings, copy_attn_type - ): - super(CNNDecoder, self).__init__() - - self.cnn_kernel_width = cnn_kernel_width - self.embeddings = embeddings - - # Decoder State - self.state = {} - - input_size = self.embeddings.embedding_size - self.linear = nn.Linear(input_size, hidden_size) - self.conv_layers = nn.ModuleList( - [GatedConv(hidden_size, cnn_kernel_width, dropout, True) for i in range(num_layers)] - ) - self.attn_layers = nn.ModuleList([ConvMultiStepAttention(hidden_size) for i in range(num_layers)]) - - # CNNDecoder has its own attention mechanism. - # Set up a separate copy attention layer if needed. - assert not copy_attn, "Copy mechanism not yet tested in conv2conv" - if copy_attn: - self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type) - else: - self.copy_attn = None - - @classmethod - def from_opt(cls, opt, embeddings): - """Alternate constructor.""" - return cls( - opt.dec_layers, - opt.dec_rnn_size, - opt.global_attention, - opt.copy_attn, - opt.cnn_kernel_width, - opt.dropout[0] if type(opt.dropout) is list else opt.dropout, - embeddings, - opt.copy_attn_type, - ) - - def init_state(self, _, memory_bank, enc_hidden): - """Init decoder state.""" - self.state["src"] = (memory_bank + enc_hidden) * SCALE_WEIGHT - self.state["previous_input"] = None - - def map_state(self, fn): - self.state["src"] = fn(self.state["src"], 1) - if self.state["previous_input"] is not None: - self.state["previous_input"] = fn(self.state["previous_input"], 1) - - def detach_state(self): - self.state["previous_input"] = self.state["previous_input"].detach() - - def forward(self, tgt, memory_bank, step=None, **kwargs): - """See :obj:`onmt.modules.RNNDecoderBase.forward()`""" - - if self.state["previous_input"] is not None: - tgt = torch.cat([self.state["previous_input"], tgt], 0) - - dec_outs = [] - attns = {"std": []} - if self.copy_attn is not None: - attns["copy"] = [] - - emb = self.embeddings(tgt) - assert emb.dim() == 3 # len x batch x embedding_dim - - tgt_emb = emb.transpose(0, 1).contiguous() - # The output of CNNEncoder. - src_memory_bank_t = memory_bank.transpose(0, 1).contiguous() - # The combination of output of CNNEncoder and source embeddings. - src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous() - - emb_reshape = tgt_emb.contiguous().view(tgt_emb.size(0) * tgt_emb.size(1), -1) - linear_out = self.linear(emb_reshape) - x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1) - x = shape_transform(x) - - pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1) - - pad = pad.type_as(x) - base_target_emb = x - - for conv, attention in zip(self.conv_layers, self.attn_layers): - new_target_input = torch.cat([pad, x], 2) - out = conv(new_target_input) - c, attn = attention(base_target_emb, out, src_memory_bank_t, src_memory_bank_c) - x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT - output = x.squeeze(3).transpose(1, 2) - - # Process the result and update the attentions. - dec_outs = output.transpose(0, 1).contiguous() - if self.state["previous_input"] is not None: - dec_outs = dec_outs[self.state["previous_input"].size(0):] - attn = attn[:, self.state["previous_input"].size(0):].squeeze() - attn = torch.stack([attn]) - attns["std"] = attn - if self.copy_attn is not None: - attns["copy"] = attn - - # Update the state. - self.state["previous_input"] = tgt - # TODO change the way attns is returned dict => list or tuple (onnx) - return dec_outs, attns - - def update_dropout(self, dropout): - for layer in self.conv_layers: - layer.dropout.p = dropout diff --git a/onmt/decoders/decoder.py b/onmt/decoders/decoder.py deleted file mode 100644 index b5bdd516..00000000 --- a/onmt/decoders/decoder.py +++ /dev/null @@ -1,428 +0,0 @@ -import torch -import torch.nn as nn - -from onmt.models.stacked_rnn import StackedLSTM, StackedGRU -from onmt.modules import context_gate_factory, GlobalAttention -from onmt.utils.rnn_factory import rnn_factory - -from onmt.utils.misc import aeq - - -class DecoderBase(nn.Module): - """Abstract class for decoders. - - Args: - attentional (bool): The decoder returns non-empty attention. - """ - - def __init__(self, attentional=True): - super(DecoderBase, self).__init__() - self.attentional = attentional - - @classmethod - def from_opt(cls, opt, embeddings): - """Alternate constructor. - - Subclasses should override this method. - """ - - raise NotImplementedError - - -class RNNDecoderBase(DecoderBase): - """Base recurrent attention-based decoder class. - - Specifies the interface used by different decoder types - and required by :class:`~onmt.models.NMTModel`. - - - .. mermaid:: - - graph BT - A[Input] - subgraph RNN - C[Pos 1] - D[Pos 2] - E[Pos N] - end - G[Decoder State] - H[Decoder State] - I[Outputs] - F[memory_bank] - A--emb-->C - A--emb-->D - A--emb-->E - H-->C - C-- attn --- F - D-- attn --- F - E-- attn --- F - C-->I - D-->I - E-->I - E-->G - F---I - - Args: - rnn_type (str): - style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU] - bidirectional_encoder (bool) : use with a bidirectional encoder - num_layers (int) : number of stacked layers - hidden_size (int) : hidden size of each layer - attn_type (str) : see :class:`~onmt.modules.GlobalAttention` - attn_func (str) : see :class:`~onmt.modules.GlobalAttention` - coverage_attn (str): see :class:`~onmt.modules.GlobalAttention` - context_gate (str): see :class:`~onmt.modules.ContextGate` - copy_attn (bool): setup a separate copy attention mechanism - dropout (float) : dropout value for :class:`torch.nn.Dropout` - embeddings (onmt.modules.Embeddings): embedding module to use - reuse_copy_attn (bool): reuse the attention for copying - copy_attn_type (str): The copy attention style. See - :class:`~onmt.modules.GlobalAttention`. - """ - - def __init__( - self, - rnn_type, - bidirectional_encoder, - num_layers, - hidden_size, - attn_type="general", - attn_func="softmax", - coverage_attn=False, - context_gate=None, - copy_attn=False, - dropout=0.0, - embeddings=None, - reuse_copy_attn=False, - copy_attn_type="general", - ): - super(RNNDecoderBase, self).__init__(attentional=attn_type != "none" and attn_type is not None) - - self.bidirectional_encoder = bidirectional_encoder - self.num_layers = num_layers - self.hidden_size = hidden_size - self.embeddings = embeddings - self.dropout = nn.Dropout(dropout) - - # Decoder state - self.state = {} - - # Build the RNN. - self.rnn = self._build_rnn( - rnn_type, input_size=self._input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout - ) - - # Set up the context gate. - self.context_gate = None - if context_gate is not None: - self.context_gate = context_gate_factory( - context_gate, self._input_size, hidden_size, hidden_size, hidden_size - ) - - # Set up the standard attention. - self._coverage = coverage_attn - if not self.attentional: - if self._coverage: - raise ValueError("Cannot use coverage term with no attention.") - self.attn = None - else: - self.attn = GlobalAttention(hidden_size, coverage=coverage_attn, attn_type=attn_type, attn_func=attn_func) - - if copy_attn and not reuse_copy_attn: - if copy_attn_type == "none" or copy_attn_type is None: - raise ValueError("Cannot use copy_attn with copy_attn_type none") - self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type, attn_func=attn_func) - else: - self.copy_attn = None - - self._reuse_copy_attn = reuse_copy_attn and copy_attn - if self._reuse_copy_attn and not self.attentional: - raise ValueError("Cannot reuse copy attention with no attention.") - - @classmethod - def from_opt(cls, opt, embeddings): - """Alternate constructor.""" - return cls( - opt.rnn_type, - opt.brnn, - opt.dec_layers, - opt.dec_rnn_size, - opt.global_attention, - opt.global_attention_function, - opt.coverage_attn, - opt.context_gate, - opt.copy_attn, - opt.dropout[0] if type(opt.dropout) is list else opt.dropout, - embeddings, - opt.reuse_copy_attn, - opt.copy_attn_type, - ) - - def init_state(self, src, memory_bank, encoder_final): - """Initialize decoder state with last state of the encoder.""" - - def _fix_enc_hidden(hidden): - # The encoder hidden is (layers*directions) x batch x dim. - # We need to convert it to layers x batch x (directions*dim). - if self.bidirectional_encoder: - hidden = torch.cat( - [hidden[0:hidden.size(0):2], hidden[1:hidden.size(0):2]], 2 - ) - return hidden - - if isinstance(encoder_final, tuple): # LSTM - self.state["hidden"] = tuple(_fix_enc_hidden(enc_hid) for enc_hid in encoder_final) - else: # GRU - self.state["hidden"] = (_fix_enc_hidden(encoder_final),) - - # Init the input feed. - batch_size = self.state["hidden"][0].size(1) - h_size = (batch_size, self.hidden_size) - self.state["input_feed"] = self.state["hidden"][0].data.new(*h_size).zero_().unsqueeze(0) - self.state["coverage"] = None - - def map_state(self, fn): - self.state["hidden"] = tuple(fn(h, 1) for h in self.state["hidden"]) - self.state["input_feed"] = fn(self.state["input_feed"], 1) - if self._coverage and self.state["coverage"] is not None: - self.state["coverage"] = fn(self.state["coverage"], 1) - - def detach_state(self): - self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"]) - self.state["input_feed"] = self.state["input_feed"].detach() - - def forward(self, tgt, memory_bank, memory_lengths=None, step=None, **kwargs): - """ - Args: - tgt (LongTensor): sequences of padded tokens - ``(tgt_len, batch, nfeats)``. - memory_bank (FloatTensor): vectors from the encoder - ``(src_len, batch, hidden)``. - memory_lengths (LongTensor): the padded source lengths - ``(batch,)``. - - Returns: - (FloatTensor, dict[str, FloatTensor]): - - * dec_outs: output from the decoder (after attn) - ``(tgt_len, batch, hidden)``. - * attns: distribution over src at each tgt - ``(tgt_len, batch, src_len)``. - """ - - dec_state, dec_outs, attns = self._run_forward_pass(tgt, memory_bank, memory_lengths=memory_lengths) - - # Update the state with the result. - if not isinstance(dec_state, tuple): - dec_state = (dec_state,) - self.state["hidden"] = dec_state - self.state["input_feed"] = dec_outs[-1].unsqueeze(0) - self.state["coverage"] = None - if "coverage" in attns: - self.state["coverage"] = attns["coverage"][-1].unsqueeze(0) - - # Concatenates sequence of tensors along a new dimension. - # NOTE: v0.3 to 0.4: dec_outs / attns[*] may not be list - # (in particular in case of SRU) it was not raising error in 0.3 - # since stack(Variable) was allowed. - # In 0.4, SRU returns a tensor that shouldn't be stacke - if type(dec_outs) == list: - dec_outs = torch.stack(dec_outs) - - for k in attns: - if type(attns[k]) == list: - attns[k] = torch.stack(attns[k]) - return dec_outs, attns - - def update_dropout(self, dropout): - self.dropout.p = dropout - self.embeddings.update_dropout(dropout) - - -class StdRNNDecoder(RNNDecoderBase): - """Standard fully batched RNN decoder with attention. - - Faster implementation, uses CuDNN for implementation. - See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options. - - - Based around the approach from - "Neural Machine Translation By Jointly Learning To Align and Translate" - :cite:`Bahdanau2015` - - - Implemented without input_feeding and currently with no `coverage_attn` - or `copy_attn` support. - """ - - def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None): - """ - Private helper for running the specific RNN forward pass. - Must be overriden by all subclasses. - - Args: - tgt (LongTensor): a sequence of input tokens tensors - ``(len, batch, nfeats)``. - memory_bank (FloatTensor): output(tensor sequence) from the - encoder RNN of size ``(src_len, batch, hidden_size)``. - memory_lengths (LongTensor): the source memory_bank lengths. - - Returns: - (Tensor, List[FloatTensor], Dict[str, List[FloatTensor]): - - * dec_state: final hidden state from the decoder. - * dec_outs: an array of output of every time - step from the decoder. - * attns: a dictionary of different - type of attention Tensor array of every time - step from the decoder. - """ - - assert self.copy_attn is None # TODO, no support yet. - assert not self._coverage # TODO, no support yet. - - attns = {} - emb = self.embeddings(tgt) - - if isinstance(self.rnn, nn.GRU): - rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0]) - else: - rnn_output, dec_state = self.rnn(emb, self.state["hidden"]) - - # Check - tgt_len, tgt_batch, _ = tgt.size() - output_len, output_batch, _ = rnn_output.size() - aeq(tgt_len, output_len) - aeq(tgt_batch, output_batch) - - # Calculate the attention. - if not self.attentional: - dec_outs = rnn_output - else: - dec_outs, p_attn = self.attn( - rnn_output.transpose(0, 1).contiguous(), memory_bank.transpose(0, 1), memory_lengths=memory_lengths - ) - attns["std"] = p_attn - - # Calculate the context gate. - if self.context_gate is not None: - dec_outs = self.context_gate( - emb.view(-1, emb.size(2)), rnn_output.view(-1, rnn_output.size(2)), dec_outs.view(-1, dec_outs.size(2)) - ) - dec_outs = dec_outs.view(tgt_len, tgt_batch, self.hidden_size) - - dec_outs = self.dropout(dec_outs) - return dec_state, dec_outs, attns - - def _build_rnn(self, rnn_type, **kwargs): - rnn, _ = rnn_factory(rnn_type, **kwargs) - return rnn - - @property - def _input_size(self): - return self.embeddings.embedding_size - - -class InputFeedRNNDecoder(RNNDecoderBase): - """Input feeding based decoder. - - See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options. - - Based around the input feeding approach from - "Effective Approaches to Attention-based Neural Machine Translation" - :cite:`Luong2015` - - - .. mermaid:: - - graph BT - A[Input n-1] - AB[Input n] - subgraph RNN - E[Pos n-1] - F[Pos n] - E --> F - end - G[Encoder] - H[memory_bank n-1] - A --> E - AB --> F - E --> H - G --> H - """ - - def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None): - """ - See StdRNNDecoder._run_forward_pass() for description - of arguments and return values. - """ - # Additional args check. - input_feed = self.state["input_feed"].squeeze(0) - input_feed_batch, _ = input_feed.size() - _, tgt_batch, _ = tgt.size() - aeq(tgt_batch, input_feed_batch) - # END Additional args check. - - dec_outs = [] - attns = {} - if self.attn is not None: - attns["std"] = [] - if self.copy_attn is not None or self._reuse_copy_attn: - attns["copy"] = [] - if self._coverage: - attns["coverage"] = [] - - emb = self.embeddings(tgt) - assert emb.dim() == 3 # len x batch x embedding_dim - - dec_state = self.state["hidden"] - coverage = self.state["coverage"].squeeze(0) if self.state["coverage"] is not None else None - - # Input feed concatenates hidden state with - # input at every time step. - for emb_t in emb.split(1): - decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1) - rnn_output, dec_state = self.rnn(decoder_input, dec_state) - if self.attentional: - decoder_output, p_attn = self.attn( - rnn_output, memory_bank.transpose(0, 1), memory_lengths=memory_lengths - ) - attns["std"].append(p_attn) - else: - decoder_output = rnn_output - if self.context_gate is not None: - # TODO: context gate should be employed - # instead of second RNN transform. - decoder_output = self.context_gate(decoder_input, rnn_output, decoder_output) - decoder_output = self.dropout(decoder_output) - input_feed = decoder_output - - dec_outs += [decoder_output] - - # Update the coverage attention. - if self._coverage: - coverage = p_attn if coverage is None else p_attn + coverage - attns["coverage"] += [coverage] - - if self.copy_attn is not None: - _, copy_attn = self.copy_attn(decoder_output, memory_bank.transpose(0, 1)) - attns["copy"] += [copy_attn] - elif self._reuse_copy_attn: - attns["copy"] = attns["std"] - - return dec_state, dec_outs, attns - - def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout): - assert rnn_type != "SRU", "SRU doesn't support input feed! Please set -input_feed 0!" - stacked_cell = StackedLSTM if rnn_type == "LSTM" else StackedGRU - return stacked_cell(num_layers, input_size, hidden_size, dropout) - - @property - def _input_size(self): - """Using input feed by concatenating input with attention vectors.""" - return self.embeddings.embedding_size + self.hidden_size - - def update_dropout(self, dropout): - self.dropout.p = dropout - self.rnn.dropout.p = dropout - self.embeddings.update_dropout(dropout) diff --git a/onmt/encoders/__init__.py b/onmt/encoders/__init__.py deleted file mode 100644 index 7885a187..00000000 --- a/onmt/encoders/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Module defining encoders.""" -from onmt.encoders.encoder import EncoderBase -from onmt.encoders.transformer import TransformerEncoder -from onmt.encoders.ggnn_encoder import GGNNEncoder -from onmt.encoders.rnn_encoder import RNNEncoder -from onmt.encoders.cnn_encoder import CNNEncoder -from onmt.encoders.mean_encoder import MeanEncoder - - -str2enc = { - "ggnn": GGNNEncoder, - "rnn": RNNEncoder, - "brnn": RNNEncoder, - "cnn": CNNEncoder, - "transformer": TransformerEncoder, - "mean": MeanEncoder, -} - -__all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", "MeanEncoder", "str2enc"] diff --git a/onmt/encoders/cnn_encoder.py b/onmt/encoders/cnn_encoder.py deleted file mode 100644 index ffa22a4f..00000000 --- a/onmt/encoders/cnn_encoder.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Implementation of "Convolutional Sequence to Sequence Learning" -""" -import torch.nn as nn - -from onmt.encoders.encoder import EncoderBase -from onmt.utils.cnn_factory import shape_transform, StackedCNN - -SCALE_WEIGHT = 0.5**0.5 - - -class CNNEncoder(EncoderBase): - """Encoder based on "Convolutional Sequence to Sequence Learning" - :cite:`DBLP:journals/corr/GehringAGYD17`. - """ - - def __init__(self, num_layers, hidden_size, cnn_kernel_width, dropout, embeddings): - super(CNNEncoder, self).__init__() - - self.embeddings = embeddings - input_size = embeddings.embedding_size - self.linear = nn.Linear(input_size, hidden_size) - self.cnn = StackedCNN(num_layers, hidden_size, cnn_kernel_width, dropout) - - @classmethod - def from_opt(cls, opt, embeddings): - """Alternate constructor.""" - return cls( - opt.enc_layers, - opt.enc_rnn_size, - opt.cnn_kernel_width, - opt.dropout[0] if type(opt.dropout) is list else opt.dropout, - embeddings, - ) - - def forward(self, input, lengths=None, hidden=None): - """See :class:`onmt.modules.EncoderBase.forward()`""" - self._check_args(input, lengths, hidden) - - emb = self.embeddings(input) - # s_len, batch, emb_dim = emb.size() - - emb = emb.transpose(0, 1).contiguous() - emb_reshape = emb.view(emb.size(0) * emb.size(1), -1) - emb_remap = self.linear(emb_reshape) - emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1) - emb_remap = shape_transform(emb_remap) - out = self.cnn(emb_remap) - - return emb_remap.squeeze(3).transpose(0, 1).contiguous(), out.squeeze(3).transpose(0, 1).contiguous(), lengths - - def update_dropout(self, dropout): - self.cnn.dropout.p = dropout diff --git a/onmt/encoders/ggnn_encoder.py b/onmt/encoders/ggnn_encoder.py deleted file mode 100644 index 209b00ab..00000000 --- a/onmt/encoders/ggnn_encoder.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Define GGNN-based encoders.""" -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from onmt.encoders.encoder import EncoderBase - - -class GGNNAttrProxy(object): - """ - Translates index lookups into attribute lookups. - To implement some trick which able to use list of nn.Module in a nn.Module - see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2 - """ - - def __init__(self, module, prefix): - self.module = module - self.prefix = prefix - - def __getitem__(self, i): - return getattr(self.module, self.prefix + str(i)) - - -class GGNNPropogator(nn.Module): - """ - Gated Propogator for GGNN - Using LSTM gating mechanism - """ - - def __init__(self, state_dim, n_node, n_edge_types): - super(GGNNPropogator, self).__init__() - - self.n_node = n_node - self.n_edge_types = n_edge_types - - self.reset_gate = nn.Sequential(nn.Linear(state_dim * 3, state_dim), nn.Sigmoid()) - self.update_gate = nn.Sequential(nn.Linear(state_dim * 3, state_dim), nn.Sigmoid()) - self.tansform = nn.Sequential(nn.Linear(state_dim * 3, state_dim), nn.LeakyReLU()) - - def forward(self, state_in, state_out, state_cur, edges, nodes): - edges_in = edges[:, :, : nodes * self.n_edge_types] - edges_out = edges[:, :, nodes * self.n_edge_types:] - - a_in = torch.bmm(edges_in, state_in) - a_out = torch.bmm(edges_out, state_out) - a = torch.cat((a_in, a_out, state_cur), 2) - - r = self.reset_gate(a) - z = self.update_gate(a) - joined_input = torch.cat((a_in, a_out, r * state_cur), 2) - h_hat = self.tansform(joined_input) - - output = (1 - z) * state_cur + z * h_hat - - return output - - -class GGNNEncoder(EncoderBase): - """A gated graph neural network configured as an encoder. - Based on github.com/JamesChuanggg/ggnn.pytorch.git, - which is based on the paper "Gated Graph Sequence Neural Networks" - by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel. - - Args: - rnn_type (str): - style of recurrent unit to use, one of [LSTM] - src_ggnn_size (int) : Size of token-to-node embedding input - src_word_vec_size (int) : Size of token-to-node embedding output - state_dim (int) : Number of state dimensions in nodes - n_edge_types (int) : Number of edge types - bidir_edges (bool): True if reverse edges should be autocreated - n_node (int) : Max nodes in graph - bridge_extra_node (bool): True indicates only 1st extra node - (after token listing) should be used for decoder init. - n_steps (int): Steps to advance graph encoder for stabilization - src_vocab (int): Path to source vocabulary.(The ggnn uses src_vocab - during training because the graph is built using edge information - which requires parsing the input sequence.) - """ - - def __init__( - self, - rnn_type, - src_word_vec_size, - src_ggnn_size, - state_dim, - bidir_edges, - n_edge_types, - n_node, - bridge_extra_node, - n_steps, - src_vocab, - ): - super(GGNNEncoder, self).__init__() - - self.src_word_vec_size = src_word_vec_size - self.src_ggnn_size = src_ggnn_size - self.state_dim = state_dim - self.n_edge_types = n_edge_types - self.n_node = n_node - self.n_steps = n_steps - self.bidir_edges = bidir_edges - self.bridge_extra_node = bridge_extra_node - - for i in range(self.n_edge_types): - # incoming and outgoing edge embedding - in_fc = nn.Linear(self.state_dim, self.state_dim) - out_fc = nn.Linear(self.state_dim, self.state_dim) - self.add_module("in_{}".format(i), in_fc) - self.add_module("out_{}".format(i), out_fc) - - self.in_fcs = GGNNAttrProxy(self, "in_") - self.out_fcs = GGNNAttrProxy(self, "out_") - - # Find vocab data for tree builting - f = open(src_vocab, "r") - idx = 0 - self.COMMA = -1 - self.DELIMITER = -1 - self.idx2num = [] - found_n_minus_one = False - for ln in f: - ln = ln.strip('\n') - ln = ln.split('\t')[0] - if idx == 0 and ln != "": - idx += 1 - self.idx2num.append(-1) - if idx == 1 and ln != "": - idx += 1 - self.idx2num.append(-1) - if ln == ",": - self.COMMA = idx - if ln == "": - self.DELIMITER = idx - if ln.isdigit(): - self.idx2num.append(int(ln)) - if int(ln) == n_node - 1: - found_n_minus_one = True - else: - self.idx2num.append(-1) - idx += 1 - - assert self.COMMA >= 0, "GGNN src_vocab must include ',' character" - assert self.DELIMITER >= 0, "GGNN src_vocab must include token" - assert found_n_minus_one, "GGNN src_vocab must include node numbers for edge connections" - - # Propogation Model - self.propogator = GGNNPropogator(self.state_dim, self.n_node, self.n_edge_types) - - self._initialization() - - # Initialize the bridge layer - self._initialize_bridge(rnn_type, self.state_dim, 1) - - # Token embedding - if src_ggnn_size > 0: - self.embed = nn.Sequential(nn.Linear(src_ggnn_size, src_word_vec_size), nn.LeakyReLU()) - assert self.src_ggnn_size >= self.DELIMITER, "Embedding input must be larger than vocabulary" - assert self.src_word_vec_size < self.state_dim, "Embedding size must be smaller than state_dim" - else: - assert self.DELIMITER < self.state_dim, "Vocabulary too large, consider -src_ggnn_size" - - @classmethod - def from_opt(cls, opt, embeddings): - """Alternate constructor.""" - return cls( - opt.rnn_type, - opt.src_word_vec_size, - opt.src_ggnn_size, - opt.state_dim, - opt.bidir_edges, - opt.n_edge_types, - opt.n_node, - opt.bridge_extra_node, - opt.n_steps, - opt.src_vocab, - ) - - def _initialization(self): - for m in self.modules(): - if isinstance(m, nn.Linear): - m.weight.data.normal_(0.0, 0.02) - m.bias.data.fill_(0) - - def forward(self, src, lengths=None): - """See :func:`EncoderBase.forward()`""" - self._check_args(src, lengths) - nodes = self.n_node - batch_size = src.size()[1] - first_extra = np.zeros(batch_size, dtype=np.int32) - token_onehot = np.zeros( - (batch_size, nodes, self.src_ggnn_size if self.src_ggnn_size > 0 else self.state_dim), dtype=np.int32 - ) - edges = np.zeros((batch_size, nodes, nodes * self.n_edge_types * 2), dtype=np.int32) - npsrc = src[:, :, 0].cpu().data.numpy().astype(np.int32) - - # Initialize graph using formatted input sequence - for i in range(batch_size): - tokens_done = False - # Number of flagged nodes defines node count for this sample - # (Nodes can have no flags on them, but must be in 'flags' list). - flag_node = 0 - flags_done = False - edge = 0 - source_node = -1 - for j in range(len(npsrc)): - token = npsrc[j][i] - if not tokens_done: - if token == self.DELIMITER: - tokens_done = True - first_extra[i] = j - else: - token_onehot[i][j][token] = 1 - elif token == self.DELIMITER: - flag_node += 1 - flags_done = True - assert flag_node <= nodes, "Too many nodes with flags" - elif not flags_done: - # The total number of integers in the vocab should allow - # for all features and edges to be defined. - if token == self.COMMA: - flag_node = 0 - else: - num = self.idx2num[token] - if num >= 0: - token_onehot[i][flag_node][num + self.DELIMITER] = 1 - flag_node += 1 - elif token == self.COMMA: - edge += 1 - assert source_node == -1, f'Error in graph edge input: {source_node} unpaired' - assert edge < self.n_edge_types, "Too many edge types in input" - else: - num = self.idx2num[token] - if source_node < 0: - source_node = num - else: - edges[i][source_node][num + nodes * edge] = 1 - if self.bidir_edges: - edges[i][num][nodes * (edge + self.n_edge_types) + source_node] = 1 - source_node = -1 - - token_onehot = torch.from_numpy(token_onehot).float().to(src.device) - if self.src_ggnn_size > 0: - token_embed = self.embed(token_onehot) - prop_state = torch.cat( - ( - token_embed, - torch.zeros((batch_size, nodes, self.state_dim - self.src_word_vec_size)).float().to(src.device), - ), - 2, - ) - else: - prop_state = token_onehot - edges = torch.from_numpy(edges).float().to(src.device) - - for i_step in range(self.n_steps): - in_states = [] - out_states = [] - for i in range(self.n_edge_types): - in_states.append(self.in_fcs[i](prop_state)) - out_states.append(self.out_fcs[i](prop_state)) - in_states = torch.stack(in_states).transpose(0, 1).contiguous() - in_states = in_states.view(-1, nodes * self.n_edge_types, self.state_dim) - out_states = torch.stack(out_states).transpose(0, 1).contiguous() - out_states = out_states.view(-1, nodes * self.n_edge_types, self.state_dim) - - prop_state = self.propogator(in_states, out_states, prop_state, edges, nodes) - - prop_state = prop_state.transpose(0, 1) - if self.bridge_extra_node: - # Use first extra node as only source for decoder init - join_state = prop_state[first_extra, torch.arange(batch_size)] - else: - # Average all nodes to get bridge input - join_state = prop_state.mean(0) - join_state = torch.stack((join_state, join_state, join_state, join_state)) - join_state = (join_state, join_state) - - encoder_final = self._bridge(join_state) - - return encoder_final, prop_state, lengths - - def _initialize_bridge(self, rnn_type, hidden_size, num_layers): - - # LSTM has hidden and cell state, other only one - number_of_states = 2 if rnn_type == "LSTM" else 1 - # Total number of states - self.total_hidden_dim = hidden_size * num_layers - - # Build a linear layer for each - self.bridge = nn.ModuleList( - [nn.Linear(self.total_hidden_dim, self.total_hidden_dim, bias=True) for _ in range(number_of_states)] - ) - - def _bridge(self, hidden): - """Forward hidden state through bridge.""" - - def bottle_hidden(linear, states): - """ - Transform from 3D to 2D, apply linear and return initial size - """ - size = states.size() - result = linear(states.view(-1, self.total_hidden_dim)) - return F.leaky_relu(result).view(size) - - if isinstance(hidden, tuple): # LSTM - outs = tuple([bottle_hidden(layer, hidden[ix]) for ix, layer in enumerate(self.bridge)]) - else: - outs = bottle_hidden(self.bridge[0], hidden) - return outs diff --git a/onmt/encoders/rnn_encoder.py b/onmt/encoders/rnn_encoder.py deleted file mode 100644 index 78271050..00000000 --- a/onmt/encoders/rnn_encoder.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Define RNN-based encoders.""" -import torch.nn as nn -import torch.nn.functional as F - -from torch.nn.utils.rnn import pack_padded_sequence as pack -from torch.nn.utils.rnn import pad_packed_sequence as unpack - -from onmt.encoders.encoder import EncoderBase -from onmt.utils.rnn_factory import rnn_factory - - -class RNNEncoder(EncoderBase): - """A generic recurrent neural network encoder. - - Args: - rnn_type (str): - style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU] - bidirectional (bool) : use a bidirectional RNN - num_layers (int) : number of stacked layers - hidden_size (int) : hidden size of each layer - dropout (float) : dropout value for :class:`torch.nn.Dropout` - embeddings (onmt.modules.Embeddings): embedding module to use - """ - - def __init__( - self, rnn_type, bidirectional, num_layers, hidden_size, dropout=0.0, embeddings=None, use_bridge=False - ): - super(RNNEncoder, self).__init__() - assert embeddings is not None - - num_directions = 2 if bidirectional else 1 - assert hidden_size % num_directions == 0 - hidden_size = hidden_size // num_directions - self.embeddings = embeddings - - self.rnn, self.no_pack_padded_seq = rnn_factory( - rnn_type, - input_size=embeddings.embedding_size, - hidden_size=hidden_size, - num_layers=num_layers, - dropout=dropout, - bidirectional=bidirectional, - ) - - # Initialize the bridge layer - self.use_bridge = use_bridge - if self.use_bridge: - self._initialize_bridge(rnn_type, hidden_size, num_layers) - - @classmethod - def from_opt(cls, opt, embeddings): - """Alternate constructor.""" - return cls( - opt.rnn_type, - opt.brnn, - opt.enc_layers, - opt.enc_rnn_size, - opt.dropout[0] if type(opt.dropout) is list else opt.dropout, - embeddings, - opt.bridge, - ) - - def forward(self, src, lengths=None): - """See :func:`EncoderBase.forward()`""" - self._check_args(src, lengths) - - emb = self.embeddings(src) - # s_len, batch, emb_dim = emb.size() - - packed_emb = emb - if lengths is not None and not self.no_pack_padded_seq: - # Lengths data is wrapped inside a Tensor. - lengths_list = lengths.view(-1).tolist() - packed_emb = pack(emb, lengths_list) - - memory_bank, encoder_final = self.rnn(packed_emb) - - if lengths is not None and not self.no_pack_padded_seq: - memory_bank = unpack(memory_bank)[0] - - if self.use_bridge: - encoder_final = self._bridge(encoder_final) - return encoder_final, memory_bank, lengths - - def _initialize_bridge(self, rnn_type, hidden_size, num_layers): - - # LSTM has hidden and cell state, other only one - number_of_states = 2 if rnn_type == "LSTM" else 1 - # Total number of states - self.total_hidden_dim = hidden_size * num_layers - - # Build a linear layer for each - self.bridge = nn.ModuleList( - [nn.Linear(self.total_hidden_dim, self.total_hidden_dim, bias=True) for _ in range(number_of_states)] - ) - - def _bridge(self, hidden): - """Forward hidden state through bridge.""" - - def bottle_hidden(linear, states): - """ - Transform from 3D to 2D, apply linear and return initial size - """ - size = states.size() - result = linear(states.view(-1, self.total_hidden_dim)) - return F.relu(result).view(size) - - if isinstance(hidden, tuple): # LSTM - outs = tuple([bottle_hidden(layer, hidden[ix]) for ix, layer in enumerate(self.bridge)]) - else: - outs = bottle_hidden(self.bridge[0], hidden) - return outs - - def update_dropout(self, dropout): - self.rnn.dropout = dropout diff --git a/onmt/models/__init__.py b/onmt/models/__init__.py deleted file mode 100644 index 7543dfe3..00000000 --- a/onmt/models/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Module defining models.""" -from onmt.models.model_saver import build_model_saver, ModelSaver -from onmt.models.model import NMTModel, LanguageModel - -__all__ = ["build_model_saver", "ModelSaver", "NMTModel", "LanguageModel"] diff --git a/onmt/models/sru.py b/onmt/models/sru.py deleted file mode 100644 index 4df30ef0..00000000 --- a/onmt/models/sru.py +++ /dev/null @@ -1,647 +0,0 @@ -""" SRU Implementation """ -# flake8: noqa - -import subprocess -import platform -import os -import re -import configargparse -import torch -import torch.nn as nn -from torch.autograd import Function -from torch.cuda.amp import custom_fwd, custom_bwd -from collections import namedtuple - - -# For command-line option parsing -class CheckSRU(configargparse.Action): - def __init__(self, option_strings, dest, **kwargs): - super(CheckSRU, self).__init__(option_strings, dest, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - if values == 'SRU': - check_sru_requirement(abort=True) - # Check pass, set the args. - setattr(namespace, self.dest, values) - - -# This SRU version implements its own cuda-level optimization, -# so it requires that: -# 1. `cupy` and `pynvrtc` python package installed. -# 2. pytorch is built with cuda support. -# 3. library path set: export LD_LIBRARY_PATH=. -def check_sru_requirement(abort=False): - """ - Return True if check pass; if check fails and abort is True, - raise an Exception, othereise return False. - """ - - # Check 1. - try: - if platform.system() == 'Windows': - subprocess.check_output('pip freeze | findstr cupy', shell=True) - subprocess.check_output('pip freeze | findstr pynvrtc', shell=True) - else: # Unix-like systems - subprocess.check_output('pip freeze | grep -w cupy', shell=True) - subprocess.check_output('pip freeze | grep -w pynvrtc', shell=True) - except subprocess.CalledProcessError: - if not abort: - return False - raise AssertionError("Using SRU requires 'cupy' and 'pynvrtc' python packages installed.") - - # Check 2. - if torch.cuda.is_available() is False: - if not abort: - return False - raise AssertionError("Using SRU requires pytorch built with cuda.") - - # Check 3. - pattern = re.compile(".*cuda/lib.*") - ld_path = os.getenv('LD_LIBRARY_PATH', "") - if re.match(pattern, ld_path) is None: - if not abort: - return False - raise AssertionError( - "Using SRU requires setting cuda lib path, e.g. export LD_LIBRARY_PATH=/usr/local/cuda/lib64." - ) - - return True - - -SRU_CODE = """ -extern "C" { - __forceinline__ __device__ float sigmoidf(float x) - { - return 1.f / (1.f + expf(-x)); - } - __forceinline__ __device__ float reluf(float x) - { - return (x > 0.f) ? x : 0.f; - } - __global__ void sru_fwd(const float * __restrict__ u, - const float * __restrict__ x, - const float * __restrict__ bias, - const float * __restrict__ init, - const float * __restrict__ mask_h, - const int len, const int batch, - const int d, const int k, - float * __restrict__ h, - float * __restrict__ c, - const int activation_type) - { - assert ((k == 3) || (x == NULL)); - int ncols = batch*d; - int col = blockIdx.x * blockDim.x + threadIdx.x; - if (col >= ncols) return; - int ncols_u = ncols*k; - int ncols_x = (k == 3) ? ncols : ncols_u; - const float bias1 = *(bias + (col%d)); - const float bias2 = *(bias + (col%d) + d); - const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); - float cur = *(init + col); - const float *up = u + (col*k); - const float *xp = (k == 3) ? (x + col) : (up + 3); - float *cp = c + col; - float *hp = h + col; - for (int row = 0; row < len; ++row) - { - float g1 = sigmoidf((*(up+1))+bias1); - float g2 = sigmoidf((*(up+2))+bias2); - cur = (cur-(*up))*g1 + (*up); - *cp = cur; - float val = (activation_type == 1) ? tanh(cur) : ( - (activation_type == 2) ? reluf(cur) : cur - ); - *hp = (val*mask-(*xp))*g2 + (*xp); - up += ncols_u; - xp += ncols_x; - cp += ncols; - hp += ncols; - } - } - __global__ void sru_bwd(const float * __restrict__ u, - const float * __restrict__ x, - const float * __restrict__ bias, - const float * __restrict__ init, - const float * __restrict__ mask_h, - const float * __restrict__ c, - const float * __restrict__ grad_h, - const float * __restrict__ grad_last, - const int len, - const int batch, const int d, const int k, - float * __restrict__ grad_u, - float * __restrict__ grad_x, - float * __restrict__ grad_bias, - float * __restrict__ grad_init, - int activation_type) - { - assert((k == 3) || (x == NULL)); - assert((k == 3) || (grad_x == NULL)); - int ncols = batch*d; - int col = blockIdx.x * blockDim.x + threadIdx.x; - if (col >= ncols) return; - int ncols_u = ncols*k; - int ncols_x = (k == 3) ? ncols : ncols_u; - const float bias1 = *(bias + (col%d)); - const float bias2 = *(bias + (col%d) + d); - const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); - float gbias1 = 0; - float gbias2 = 0; - float cur = *(grad_last + col); - const float *up = u + (col*k) + (len-1)*ncols_u; - const float *xp = (k == 3) ? (x + col + (len-1)*ncols) : (up + 3); - const float *cp = c + col + (len-1)*ncols; - const float *ghp = grad_h + col + (len-1)*ncols; - float *gup = grad_u + (col*k) + (len-1)*ncols_u; - float *gxp = (k == 3) ? (grad_x + col + (len-1)*ncols) : (gup + 3); - for (int row = len-1; row >= 0; --row) - { - const float g1 = sigmoidf((*(up+1))+bias1); - const float g2 = sigmoidf((*(up+2))+bias2); - const float c_val = (activation_type == 1) ? tanh(*cp) : ( - (activation_type == 2) ? reluf(*cp) : (*cp) - ); - const float x_val = *xp; - const float u_val = *up; - const float prev_c_val = (row>0) ? (*(cp-ncols)) : (*(init+col)); - const float gh_val = *ghp; - // h = c*g2 + x*(1-g2) = (c-x)*g2 + x - // c = c'*g1 + g0*(1-g1) = (c'-g0)*g1 + g0 - // grad wrt x - *gxp = gh_val*(1-g2); - // grad wrt g2, u2 and bias2 - float gg2 = gh_val*(c_val*mask-x_val)*(g2*(1-g2)); - *(gup+2) = gg2; - gbias2 += gg2; - // grad wrt c - const float tmp = (activation_type == 1) ? (g2*(1-c_val*c_val)) : ( - ((activation_type == 0) || (c_val > 0)) ? g2 : 0.f - ); - const float gc = gh_val*mask*tmp + cur; - // grad wrt u0 - *gup = gc*(1-g1); - // grad wrt g1, u1, and bias1 - float gg1 = gc*(prev_c_val-u_val)*(g1*(1-g1)); - *(gup+1) = gg1; - gbias1 += gg1; - // grad wrt c' - cur = gc*g1; - up -= ncols_u; - xp -= ncols_x; - cp -= ncols; - gup -= ncols_u; - gxp -= ncols_x; - ghp -= ncols; - } - *(grad_bias + col) = gbias1; - *(grad_bias + col + ncols) = gbias2; - *(grad_init +col) = cur; - } - __global__ void sru_bi_fwd(const float * __restrict__ u, - const float * __restrict__ x, - const float * __restrict__ bias, - const float * __restrict__ init, - const float * __restrict__ mask_h, - const int len, const int batch, - const int d, const int k, - float * __restrict__ h, - float * __restrict__ c, - const int activation_type) - { - assert ((k == 3) || (x == NULL)); - assert ((k == 3) || (k == 4)); - int ncols = batch*d*2; - int col = blockIdx.x * blockDim.x + threadIdx.x; - if (col >= ncols) return; - int ncols_u = ncols*k; - int ncols_x = (k == 3) ? ncols : ncols_u; - const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); - float cur = *(init + col); - const int d2 = d*2; - const bool flip = (col%d2) >= d; - const float bias1 = *(bias + (col%d2)); - const float bias2 = *(bias + (col%d2) + d2); - const float *up = u + (col*k); - const float *xp = (k == 3) ? (x + col) : (up + 3); - float *cp = c + col; - float *hp = h + col; - if (flip) { - up += (len-1)*ncols_u; - xp += (len-1)*ncols_x; - cp += (len-1)*ncols; - hp += (len-1)*ncols; - } - int ncols_u_ = flip ? -ncols_u : ncols_u; - int ncols_x_ = flip ? -ncols_x : ncols_x; - int ncols_ = flip ? -ncols : ncols; - for (int cnt = 0; cnt < len; ++cnt) - { - float g1 = sigmoidf((*(up+1))+bias1); - float g2 = sigmoidf((*(up+2))+bias2); - cur = (cur-(*up))*g1 + (*up); - *cp = cur; - float val = (activation_type == 1) ? tanh(cur) : ( - (activation_type == 2) ? reluf(cur) : cur - ); - *hp = (val*mask-(*xp))*g2 + (*xp); - up += ncols_u_; - xp += ncols_x_; - cp += ncols_; - hp += ncols_; - } - } - __global__ void sru_bi_bwd(const float * __restrict__ u, - const float * __restrict__ x, - const float * __restrict__ bias, - const float * __restrict__ init, - const float * __restrict__ mask_h, - const float * __restrict__ c, - const float * __restrict__ grad_h, - const float * __restrict__ grad_last, - const int len, const int batch, - const int d, const int k, - float * __restrict__ grad_u, - float * __restrict__ grad_x, - float * __restrict__ grad_bias, - float * __restrict__ grad_init, - int activation_type) - { - assert((k == 3) || (x == NULL)); - assert((k == 3) || (grad_x == NULL)); - assert((k == 3) || (k == 4)); - int ncols = batch*d*2; - int col = blockIdx.x * blockDim.x + threadIdx.x; - if (col >= ncols) return; - int ncols_u = ncols*k; - int ncols_x = (k == 3) ? ncols : ncols_u; - const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); - float gbias1 = 0; - float gbias2 = 0; - float cur = *(grad_last + col); - const int d2 = d*2; - const bool flip = ((col%d2) >= d); - const float bias1 = *(bias + (col%d2)); - const float bias2 = *(bias + (col%d2) + d2); - const float *up = u + (col*k); - const float *xp = (k == 3) ? (x + col) : (up + 3); - const float *cp = c + col; - const float *ghp = grad_h + col; - float *gup = grad_u + (col*k); - float *gxp = (k == 3) ? (grad_x + col) : (gup + 3); - if (!flip) { - up += (len-1)*ncols_u; - xp += (len-1)*ncols_x; - cp += (len-1)*ncols; - ghp += (len-1)*ncols; - gup += (len-1)*ncols_u; - gxp += (len-1)*ncols_x; - } - int ncols_u_ = flip ? -ncols_u : ncols_u; - int ncols_x_ = flip ? -ncols_x : ncols_x; - int ncols_ = flip ? -ncols : ncols; - for (int cnt = 0; cnt < len; ++cnt) - { - const float g1 = sigmoidf((*(up+1))+bias1); - const float g2 = sigmoidf((*(up+2))+bias2); - const float c_val = (activation_type == 1) ? tanh(*cp) : ( - (activation_type == 2) ? reluf(*cp) : (*cp) - ); - const float x_val = *xp; - const float u_val = *up; - const float prev_c_val = (cnt 0)) ? g2 : 0.f - ); - const float gc = gh_val*mask*tmp + cur; - // grad wrt u0 - *gup = gc*(1-g1); - // grad wrt g1, u1, and bias1 - float gg1 = gc*(prev_c_val-u_val)*(g1*(1-g1)); - *(gup+1) = gg1; - gbias1 += gg1; - // grad wrt c' - cur = gc*g1; - up -= ncols_u_; - xp -= ncols_x_; - cp -= ncols_; - gup -= ncols_u_; - gxp -= ncols_x_; - ghp -= ncols_; - } - *(grad_bias + col) = gbias1; - *(grad_bias + col + ncols) = gbias2; - *(grad_init +col) = cur; - } -} -""" -SRU_FWD_FUNC, SRU_BWD_FUNC = None, None -SRU_BiFWD_FUNC, SRU_BiBWD_FUNC = None, None -SRU_STREAM = None - - -def load_sru_mod(): - global SRU_FWD_FUNC, SRU_BWD_FUNC, SRU_BiFWD_FUNC, SRU_BiBWD_FUNC - global SRU_STREAM - if check_sru_requirement(): - from cupy.cuda import function - from pynvrtc.compiler import Program - - # This sets up device to use. - device = torch.device("cuda") - tmp_ = torch.rand(1, 1).to(device) - - sru_prog = Program(SRU_CODE.encode('utf-8'), 'sru_prog.cu'.encode('utf-8')) - sru_ptx = sru_prog.compile() - sru_mod = function.Module() - sru_mod.load(bytes(sru_ptx.encode())) - - SRU_FWD_FUNC = sru_mod.get_function('sru_fwd') - SRU_BWD_FUNC = sru_mod.get_function('sru_bwd') - SRU_BiFWD_FUNC = sru_mod.get_function('sru_bi_fwd') - SRU_BiBWD_FUNC = sru_mod.get_function('sru_bi_bwd') - - stream = namedtuple('Stream', ['ptr']) - SRU_STREAM = stream(ptr=torch.cuda.current_stream().cuda_stream) - - -class SRU_Compute(Function): - def __init__(self, activation_type, d_out, bidirectional=False): - SRU_Compute.maybe_load_sru_mod() - super(SRU_Compute, self).__init__() - self.activation_type = activation_type - self.d_out = d_out - self.bidirectional = bidirectional - - @staticmethod - def maybe_load_sru_mod(): - global SRU_FWD_FUNC - - if SRU_FWD_FUNC is None: - load_sru_mod() - - @custom_fwd - def forward(self, u, x, bias, init=None, mask_h=None): - bidir = 2 if self.bidirectional else 1 - length = x.size(0) if x.dim() == 3 else 1 - batch = x.size(-2) - d = self.d_out - k = u.size(-1) // d - k_ = k // 2 if self.bidirectional else k - ncols = batch * d * bidir - thread_per_block = min(512, ncols) - num_block = (ncols - 1) // thread_per_block + 1 - - init_ = x.new(ncols).zero_() if init is None else init - size = (length, batch, d * bidir) if x.dim() == 3 else (batch, d * bidir) - c = x.new(*size) - h = x.new(*size) - - FUNC = SRU_FWD_FUNC if not self.bidirectional else SRU_BiFWD_FUNC - FUNC( - args=[ - u.contiguous().data_ptr(), - x.contiguous().data_ptr() if k_ == 3 else 0, - bias.data_ptr(), - init_.contiguous().data_ptr(), - mask_h.data_ptr() if mask_h is not None else 0, - length, - batch, - d, - k_, - h.data_ptr(), - c.data_ptr(), - self.activation_type, - ], - block=(thread_per_block, 1, 1), - grid=(num_block, 1, 1), - stream=SRU_STREAM, - ) - - self.save_for_backward(u, x, bias, init, mask_h) - self.intermediate = c - if x.dim() == 2: - last_hidden = c - elif self.bidirectional: - # -> directions x batch x dim - last_hidden = torch.stack((c[-1, :, :d], c[0, :, d:])) - else: - last_hidden = c[-1] - return h, last_hidden - - @custom_bwd - def backward(self, grad_h, grad_last): - if self.bidirectional: - grad_last = torch.cat((grad_last[0], grad_last[1]), 1) - bidir = 2 if self.bidirectional else 1 - u, x, bias, init, mask_h = self.saved_tensors - c = self.intermediate - length = x.size(0) if x.dim() == 3 else 1 - batch = x.size(-2) - d = self.d_out - k = u.size(-1) // d - k_ = k // 2 if self.bidirectional else k - ncols = batch * d * bidir - thread_per_block = min(512, ncols) - num_block = (ncols - 1) // thread_per_block + 1 - - init_ = x.new(ncols).zero_() if init is None else init - grad_u = u.new(*u.size()) - grad_bias = x.new(2, batch, d * bidir) - grad_init = x.new(batch, d * bidir) - - # For DEBUG - # size = (length, batch, x.size(-1)) \ - # if x.dim() == 3 else (batch, x.size(-1)) - # grad_x = x.new(*x.size()) if k_ == 3 else x.new(*size).zero_() - - # Normal use - grad_x = x.new(*x.size()) if k_ == 3 else None - - FUNC = SRU_BWD_FUNC if not self.bidirectional else SRU_BiBWD_FUNC - FUNC( - args=[ - u.contiguous().data_ptr(), - x.contiguous().data_ptr() if k_ == 3 else 0, - bias.data_ptr(), - init_.contiguous().data_ptr(), - mask_h.data_ptr() if mask_h is not None else 0, - c.data_ptr(), - grad_h.contiguous().data_ptr(), - grad_last.contiguous().data_ptr(), - length, - batch, - d, - k_, - grad_u.data_ptr(), - grad_x.data_ptr() if k_ == 3 else 0, - grad_bias.data_ptr(), - grad_init.data_ptr(), - self.activation_type, - ], - block=(thread_per_block, 1, 1), - grid=(num_block, 1, 1), - stream=SRU_STREAM, - ) - return grad_u, grad_x, grad_bias.sum(1).view(-1), grad_init, None - - -class SRUCell(nn.Module): - def __init__(self, n_in, n_out, dropout=0, rnn_dropout=0, bidirectional=False, use_tanh=1, use_relu=0): - super(SRUCell, self).__init__() - self.n_in = n_in - self.n_out = n_out - self.rnn_dropout = rnn_dropout - self.dropout = dropout - self.bidirectional = bidirectional - self.activation_type = 2 if use_relu else (1 if use_tanh else 0) - - out_size = n_out * 2 if bidirectional else n_out - k = 4 if n_in != out_size else 3 - self.size_per_dir = n_out * k - self.weight = nn.Parameter(torch.Tensor(n_in, self.size_per_dir * 2 if bidirectional else self.size_per_dir)) - self.bias = nn.Parameter(torch.Tensor(n_out * 4 if bidirectional else n_out * 2)) - self.init_weight() - - def init_weight(self): - val_range = (3.0 / self.n_in) ** 0.5 - self.weight.data.uniform_(-val_range, val_range) - self.bias.data.zero_() - - def set_bias(self, bias_val=0): - n_out = self.n_out - if self.bidirectional: - self.bias.data[n_out * 2 :].zero_().add_(bias_val) - else: - self.bias.data[n_out:].zero_().add_(bias_val) - - def forward(self, input, c0=None): - assert input.dim() == 2 or input.dim() == 3 - n_in, n_out = self.n_in, self.n_out - batch = input.size(-2) - if c0 is None: - c0 = input.data.new(batch, n_out if not self.bidirectional else n_out * 2).zero_() - - if self.training and (self.rnn_dropout > 0): - mask = self.get_dropout_mask_((batch, n_in), self.rnn_dropout) - x = input * mask.expand_as(input) - else: - x = input - - x_2d = x if x.dim() == 2 else x.contiguous().view(-1, n_in) - u = x_2d.mm(self.weight) - - if self.training and (self.dropout > 0): - bidir = 2 if self.bidirectional else 1 - mask_h = self.get_dropout_mask_((batch, n_out * bidir), self.dropout) - h, c = SRU_Compute(self.activation_type, n_out, self.bidirectional)(u, input, self.bias, c0, mask_h) - else: - h, c = SRU_Compute(self.activation_type, n_out, self.bidirectional)(u, input, self.bias, c0) - - return h, c - - def get_dropout_mask_(self, size, p): - w = self.weight.data - return w.new(*size).bernoulli_(1 - p).div_(1 - p) - - -class SRU(nn.Module): - """ - Implementation of "Training RNNs as Fast as CNNs" - :cite:`DBLP:journals/corr/abs-1709-02755` - - TODO: turn to pytorch's implementation when it is available. - - This implementation is adpoted from the author of the paper: - https://github.com/taolei87/sru/blob/master/cuda_functional.py. - - Args: - input_size (int): input to model - hidden_size (int): hidden dimension - num_layers (int): number of layers - dropout (float): dropout to use (stacked) - rnn_dropout (float): dropout to use (recurrent) - bidirectional (bool): bidirectional - use_tanh (bool): activation - use_relu (bool): activation - """ - - def __init__( - self, - input_size, - hidden_size, - num_layers=2, - dropout=0, - rnn_dropout=0, - bidirectional=False, - use_tanh=1, - use_relu=0, - ): - # An entry check here, will catch on train side and translate side - # if requirements are not satisfied. - check_sru_requirement(abort=True) - super(SRU, self).__init__() - self.n_in = input_size - self.n_out = hidden_size - self.depth = num_layers - self.dropout = dropout - self.rnn_dropout = rnn_dropout - self.rnn_lst = nn.ModuleList() - self.bidirectional = bidirectional - self.out_size = hidden_size * 2 if bidirectional else hidden_size - - for i in range(num_layers): - sru_cell = SRUCell( - n_in=self.n_in if i == 0 else self.out_size, - n_out=self.n_out, - dropout=dropout if i + 1 != num_layers else 0, - rnn_dropout=rnn_dropout, - bidirectional=bidirectional, - use_tanh=use_tanh, - use_relu=use_relu, - ) - self.rnn_lst.append(sru_cell) - - def set_bias(self, bias_val=0): - for l in self.rnn_lst: - l.set_bias(bias_val) - - def forward(self, input, c0=None, return_hidden=True): - assert input.dim() == 3 # (len, batch, n_in) - dir_ = 2 if self.bidirectional else 1 - if c0 is None: - zeros = input.data.new(input.size(1), self.n_out * dir_).zero_() - c0 = [zeros for i in range(self.depth)] - else: - if isinstance(c0, tuple): - # RNNDecoderState wraps hidden as a tuple. - c0 = c0[0] - assert c0.dim() == 3 # (depth, batch, dir_*n_out) - c0 = [h.squeeze(0) for h in c0.chunk(self.depth, 0)] - - prevx = input - lstc = [] - for i, rnn in enumerate(self.rnn_lst): - h, c = rnn(prevx, c0[i]) - prevx = h - lstc.append(c) - - if self.bidirectional: - # fh -> (layers*directions) x batch x dim - fh = torch.cat(lstc) - else: - fh = torch.stack(lstc) - - if return_hidden: - return prevx, fh - else: - return prevx diff --git a/onmt/models/stacked_rnn.py b/onmt/models/stacked_rnn.py deleted file mode 100644 index cb201f04..00000000 --- a/onmt/models/stacked_rnn.py +++ /dev/null @@ -1,65 +0,0 @@ -""" Implementation of ONMT RNN for Input Feeding Decoding """ -import torch -import torch.nn as nn - - -class StackedLSTM(nn.Module): - """ - Our own implementation of stacked LSTM. - Needed for the decoder, because we do input feeding. - """ - - def __init__(self, num_layers, input_size, rnn_size, dropout): - super(StackedLSTM, self).__init__() - self.dropout = nn.Dropout(dropout) - self.num_layers = num_layers - self.layers = nn.ModuleList() - - for _ in range(num_layers): - self.layers.append(nn.LSTMCell(input_size, rnn_size)) - input_size = rnn_size - - def forward(self, input_feed, hidden): - h_0, c_0 = hidden - h_1, c_1 = [], [] - for i, layer in enumerate(self.layers): - h_1_i, c_1_i = layer(input_feed, (h_0[i], c_0[i])) - input_feed = h_1_i - if i + 1 != self.num_layers: - input_feed = self.dropout(input_feed) - h_1 += [h_1_i] - c_1 += [c_1_i] - - h_1 = torch.stack(h_1) - c_1 = torch.stack(c_1) - - return input_feed, (h_1, c_1) - - -class StackedGRU(nn.Module): - """ - Our own implementation of stacked GRU. - Needed for the decoder, because we do input feeding. - """ - - def __init__(self, num_layers, input_size, rnn_size, dropout): - super(StackedGRU, self).__init__() - self.dropout = nn.Dropout(dropout) - self.num_layers = num_layers - self.layers = nn.ModuleList() - - for _ in range(num_layers): - self.layers.append(nn.GRUCell(input_size, rnn_size)) - input_size = rnn_size - - def forward(self, input_feed, hidden): - h_1 = [] - for i, layer in enumerate(self.layers): - h_1_i = layer(input_feed, hidden[0][i]) - input_feed = h_1_i - if i + 1 != self.num_layers: - input_feed = self.dropout(input_feed) - h_1 += [h_1_i] - - h_1 = torch.stack(h_1) - return input_feed, (h_1,) diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py deleted file mode 100644 index 45122e53..00000000 --- a/onmt/modules/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" Attention and normalization modules """ -from onmt.modules.util_class import Elementwise -from onmt.modules.gate import context_gate_factory, ContextGate -from onmt.modules.global_attention import GlobalAttention -from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention -from onmt.modules.copy_generator import ( - CopyGenerator, - CopyGeneratorLoss, - CopyGeneratorLossCompute, - CopyGeneratorLMLossCompute, -) -from onmt.modules.multi_headed_attn import MultiHeadedAttention -from onmt.modules.embeddings import Embeddings, PositionalEncoding -from onmt.modules.weight_norm import WeightNormConv2d -from onmt.modules.average_attn import AverageAttention -from onmt.modules.stable_embeddings import StableEmbedding - -__all__ = [ - "Elementwise", - "context_gate_factory", - "ContextGate", - "GlobalAttention", - "ConvMultiStepAttention", - "CopyGenerator", - "CopyGeneratorLoss", - "CopyGeneratorLossCompute", - "MultiHeadedAttention", - "Embeddings", - "PositionalEncoding", - "WeightNormConv2d", - "AverageAttention", - "CopyGeneratorLMLossCompute", - "StableEmbedding", -] diff --git a/onmt/modules/conv_multi_step_attention.py b/onmt/modules/conv_multi_step_attention.py deleted file mode 100644 index fe1fd4b4..00000000 --- a/onmt/modules/conv_multi_step_attention.py +++ /dev/null @@ -1,76 +0,0 @@ -""" Multi Step Attention for CNN """ -import torch -import torch.nn as nn -import torch.nn.functional as F -from onmt.utils.misc import aeq - - -SCALE_WEIGHT = 0.5**0.5 - - -def seq_linear(linear, x): - """linear transform for 3-d tensor""" - batch, hidden_size, length, _ = x.size() - h = linear(torch.transpose(x, 1, 2).contiguous().view(batch * length, hidden_size)) - return torch.transpose(h.view(batch, length, hidden_size, 1), 1, 2) - - -class ConvMultiStepAttention(nn.Module): - """ - Conv attention takes a key matrix, a value matrix and a query vector. - Attention weight is calculated by key matrix with the query vector - and sum on the value matrix. And the same operation is applied - in each decode conv layer. - """ - - def __init__(self, input_size): - super(ConvMultiStepAttention, self).__init__() - self.linear_in = nn.Linear(input_size, input_size) - self.mask = None - - def apply_mask(self, mask): - """Apply mask""" - self.mask = mask - - def forward(self, base_target_emb, input_from_dec, encoder_out_top, encoder_out_combine): - """ - Args: - base_target_emb: target emb tensor - input_from_dec: output of decode conv - encoder_out_top: the key matrix for calculation of attetion weight, - which is the top output of encode conv - encoder_out_combine: - the value matrix for the attention-weighted sum, - which is the combination of base emb and top output of encode - """ - - # checks - # batch, channel, height, width = base_target_emb.size() - batch, _, height, _ = base_target_emb.size() - # batch_, channel_, height_, width_ = input_from_dec.size() - batch_, _, height_, _ = input_from_dec.size() - aeq(batch, batch_) - aeq(height, height_) - - # enc_batch, enc_channel, enc_height = encoder_out_top.size() - enc_batch, _, enc_height = encoder_out_top.size() - # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size() - enc_batch_, _, enc_height_ = encoder_out_combine.size() - - aeq(enc_batch, enc_batch_) - aeq(enc_height, enc_height_) - - preatt = seq_linear(self.linear_in, input_from_dec) - target = (base_target_emb + preatt) * SCALE_WEIGHT - target = torch.squeeze(target, 3) - target = torch.transpose(target, 1, 2) - pre_attn = torch.bmm(target, encoder_out_top) - - if self.mask is not None: - pre_attn.data.masked_fill_(self.mask, -float('inf')) - - attn = F.softmax(pre_attn, dim=2) - - context_output = torch.bmm(attn, torch.transpose(encoder_out_combine, 1, 2)) - context_output = torch.transpose(torch.unsqueeze(context_output, 3), 1, 2) - return context_output, attn diff --git a/onmt/modules/copy_generator.py b/onmt/modules/copy_generator.py deleted file mode 100644 index 3e426119..00000000 --- a/onmt/modules/copy_generator.py +++ /dev/null @@ -1,264 +0,0 @@ -import torch -import torch.nn as nn - -from onmt.utils.misc import aeq -from onmt.utils.loss import CommonLossCompute - - -def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None, batch_dim=1, batch_offset=None): - """ - Given scores from an expanded dictionary - corresponeding to a batch, sums together copies, - with a dictionary word when it is ambiguous. - """ - offset = len(tgt_vocab) - for b in range(scores.size(batch_dim)): - blank = [] - fill = [] - - if src_vocabs is None: - src_vocab = batch.src_ex_vocab[b] - else: - batch_id = batch_offset[b] if batch_offset is not None else b - index = batch.indices.data[batch_id] - src_vocab = src_vocabs[index] - - for i in range(1, len(src_vocab)): - sw = src_vocab.itos[i] - ti = tgt_vocab.stoi[sw] - if ti != 0: - blank.append(offset + i) - fill.append(ti) - if blank: - blank = torch.Tensor(blank).type_as(batch.indices.data) - fill = torch.Tensor(fill).type_as(batch.indices.data) - score = scores[:, b] if batch_dim == 1 else scores[b] - score.index_add_(1, fill, score.index_select(1, blank)) - score.index_fill_(1, blank, 1e-10) - return scores - - -class CopyGenerator(nn.Module): - """An implementation of pointer-generator networks - :cite:`DBLP:journals/corr/SeeLM17`. - - These networks consider copying words - directly from the source sequence. - - The copy generator is an extended version of the standard - generator that computes three values. - - * :math:`p_{softmax}` the standard softmax over `tgt_dict` - * :math:`p(z)` the probability of copying a word from - the source - * :math:`p_{copy}` the probility of copying a particular word. - taken from the attention distribution directly. - - The model returns a distribution over the extend dictionary, - computed as - - :math:`p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)` - - - .. mermaid:: - - graph BT - A[input] - S[src_map] - B[softmax] - BB[switch] - C[attn] - D[copy] - O[output] - A --> B - A --> BB - S --> D - C --> D - D --> O - B --> O - BB --> O - - - Args: - input_size (int): size of input representation - output_size (int): size of output vocabulary - pad_idx (int) - """ - - def __init__(self, input_size, output_size, pad_idx): - super(CopyGenerator, self).__init__() - self.linear = nn.Linear(input_size, output_size) - self.linear_copy = nn.Linear(input_size, 1) - self.pad_idx = pad_idx - - def forward(self, hidden, attn, src_map): - """ - Compute a distribution over the target dictionary - extended by the dynamic dictionary implied by copying - source words. - - Args: - hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)`` - attn (FloatTensor): attn for each ``(batch x tlen, slen)`` - src_map (FloatTensor): - A sparse indicator matrix mapping each source word to - its index in the "extended" vocab containing. - ``(src_len, batch, extra_words)`` - """ - - # CHECKS - batch_by_tlen, _ = hidden.size() - batch_by_tlen_, slen = attn.size() - slen_, batch, cvocab = src_map.size() - aeq(batch_by_tlen, batch_by_tlen_) - aeq(slen, slen_) - - # Original probabilities. - logits = self.linear(hidden) - logits[:, self.pad_idx] = -float('inf') - prob = torch.softmax(logits, 1) - - # Probability of copying p(z=1) batch. - p_copy = torch.sigmoid(self.linear_copy(hidden)) - # Probability of not copying: p_{word}(w) * (1 - p(z)) - out_prob = torch.mul(prob, 1 - p_copy) - mul_attn = torch.mul(attn, p_copy) - copy_prob = torch.bmm(mul_attn.view(-1, batch, slen).transpose(0, 1), src_map.transpose(0, 1)).transpose(0, 1) - copy_prob = copy_prob.contiguous().view(-1, cvocab) - return torch.cat([out_prob, copy_prob], 1) - - -class CopyGeneratorLoss(nn.Module): - """Copy generator criterion.""" - - def __init__(self, vocab_size, force_copy, unk_index=0, ignore_index=-100, eps=1e-20): - super(CopyGeneratorLoss, self).__init__() - self.force_copy = force_copy - self.eps = eps - self.vocab_size = vocab_size - self.ignore_index = ignore_index - self.unk_index = unk_index - - def forward(self, scores, align, target): - """ - Args: - scores (FloatTensor): ``(batch_size*tgt_len)`` x dynamic vocab size - whose sum along dim 1 is less than or equal to 1, i.e. cols - softmaxed. - align (LongTensor): ``(batch_size x tgt_len)`` - target (LongTensor): ``(batch_size x tgt_len)`` - """ - # probabilities assigned by the model to the gold targets - vocab_probs = scores.gather(1, target.unsqueeze(1)).squeeze(1) - - # probability of tokens copied from source - copy_ix = align.unsqueeze(1) + self.vocab_size - copy_tok_probs = scores.gather(1, copy_ix).squeeze(1) - # Set scores for unk to 0 and add eps - copy_tok_probs[align == self.unk_index] = 0 - copy_tok_probs += self.eps # to avoid -inf logs - - # find the indices in which you do not use the copy mechanism - non_copy = align == self.unk_index - if not self.force_copy: - non_copy = non_copy | (target != self.unk_index) - - probs = torch.where(non_copy, copy_tok_probs + vocab_probs, copy_tok_probs) - - loss = -probs.log() # just NLLLoss; can the module be incorporated? - # Drop padding. - loss[target == self.ignore_index] = 0 - return loss - - -class CommonCopyGeneratorLossCompute(CommonLossCompute): - """Common Copy Generator Loss Computation.""" - - def __init__(self, criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0, tgt_shift_index=1): - super(CommonCopyGeneratorLossCompute, self).__init__( - criterion, generator, lambda_coverage=lambda_coverage, tgt_shift_index=tgt_shift_index - ) - self.tgt_vocab = tgt_vocab - self.normalize_by_length = normalize_by_length - - def _compute_loss(self, batch, output, target, copy_attn, align, std_attn=None, coverage_attn=None): - """Compute the loss. - - The args must match :func:`self._make_shard_state()`. - - Args: - batch: the current batch. - output: the predict output from the model. - target: the validate target to compare output with. - copy_attn: the copy attention value. - align: the align info. - """ - target = target.view(-1) - align = align.view(-1) - scores = self.generator(self._bottle(output), self._bottle(copy_attn), batch.src_map) - loss = self.criterion(scores, align, target) - - if self.lambda_coverage != 0.0: - coverage_loss = self._compute_coverage_loss(std_attn, coverage_attn) - loss += coverage_loss - - # this block does not depend on the loss value computed above - # and is used only for stats - scores_data = collapse_copy_scores( - self._unbottle(scores.clone(), batch.batch_size), batch, self.tgt_vocab, None - ) - scores_data = self._bottle(scores_data) - - # this block does not depend on the loss value computed above - # and is used only for stats - # Correct target copy token instead of - # tgt[i] = align[i] + len(tgt_vocab) - # for i such that tgt[i] == 0 and align[i] != 0 - target_data = target.clone() - unk = self.criterion.unk_index - correct_mask = (target_data == unk) & (align != unk) - offset_align = align[correct_mask] + len(self.tgt_vocab) - target_data[correct_mask] += offset_align - - # Compute sum of perplexities for stats - stats = self._stats(loss.sum().clone(), scores_data, target_data) - - # this part looks like it belongs in CopyGeneratorLoss - if self.normalize_by_length: - # Compute Loss as NLL divided by seq length - tgt_lens = batch.tgt[:, :, 0].ne(self.padding_idx).sum(0).float() - # Compute Total Loss per sequence in batch - loss = loss.view(-1, batch.batch_size).sum(0) - # Divide by length of each sequence and sum - loss = torch.div(loss, tgt_lens).sum() - else: - loss = loss.sum() - - return loss, stats - - def _make_shard_state(self, batch, output, range_, attns): - """See base class for args description.""" - shard_state = super(CommonCopyGeneratorLossCompute, self)._make_shard_state(batch, output, range_, attns) - - start_range = range_[0] + self.tgt_shift_index - end_range = range_[1] - shard_state.update({"copy_attn": attns.get("copy"), "align": batch.alignment[start_range:end_range]}) - return shard_state - - -class CopyGeneratorLossCompute(CommonCopyGeneratorLossCompute): - """Copy Generator Loss Computation.""" - - def __init__(self, criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0): - super(CopyGeneratorLossCompute, self).__init__( - criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0, tgt_shift_index=1 - ) - - -class CopyGeneratorLMLossCompute(CommonCopyGeneratorLossCompute): - """Copy Generator LM Loss Computation.""" - - def __init__(self, criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0): - super(CopyGeneratorLMLossCompute, self).__init__( - criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0, tgt_shift_index=0 - ) diff --git a/onmt/modules/gate.py b/onmt/modules/gate.py deleted file mode 100644 index 86babaf5..00000000 --- a/onmt/modules/gate.py +++ /dev/null @@ -1,76 +0,0 @@ -""" ContextGate module """ -import torch -import torch.nn as nn - - -def context_gate_factory(gate_type, embeddings_size, decoder_size, attention_size, output_size): - """Returns the correct ContextGate class""" - - gate_types = {'source': SourceContextGate, 'target': TargetContextGate, 'both': BothContextGate} - - assert gate_type in gate_types, "Not valid ContextGate type: {0}".format(gate_type) - return gate_types[gate_type](embeddings_size, decoder_size, attention_size, output_size) - - -class ContextGate(nn.Module): - """ - Context gate is a decoder module that takes as input the previous word - embedding, the current decoder state and the attention state, and - produces a gate. - The gate can be used to select the input from the target side context - (decoder state), from the source context (attention state) or both. - """ - - def __init__(self, embeddings_size, decoder_size, attention_size, output_size): - super(ContextGate, self).__init__() - input_size = embeddings_size + decoder_size + attention_size - self.gate = nn.Linear(input_size, output_size, bias=True) - self.sig = nn.Sigmoid() - self.source_proj = nn.Linear(attention_size, output_size) - self.target_proj = nn.Linear(embeddings_size + decoder_size, output_size) - - def forward(self, prev_emb, dec_state, attn_state): - input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1) - z = self.sig(self.gate(input_tensor)) - proj_source = self.source_proj(attn_state) - proj_target = self.target_proj(torch.cat((prev_emb, dec_state), dim=1)) - return z, proj_source, proj_target - - -class SourceContextGate(nn.Module): - """Apply the context gate only to the source context""" - - def __init__(self, embeddings_size, decoder_size, attention_size, output_size): - super(SourceContextGate, self).__init__() - self.context_gate = ContextGate(embeddings_size, decoder_size, attention_size, output_size) - self.tanh = nn.Tanh() - - def forward(self, prev_emb, dec_state, attn_state): - z, source, target = self.context_gate(prev_emb, dec_state, attn_state) - return self.tanh(target + z * source) - - -class TargetContextGate(nn.Module): - """Apply the context gate only to the target context""" - - def __init__(self, embeddings_size, decoder_size, attention_size, output_size): - super(TargetContextGate, self).__init__() - self.context_gate = ContextGate(embeddings_size, decoder_size, attention_size, output_size) - self.tanh = nn.Tanh() - - def forward(self, prev_emb, dec_state, attn_state): - z, source, target = self.context_gate(prev_emb, dec_state, attn_state) - return self.tanh(z * target + source) - - -class BothContextGate(nn.Module): - """Apply the context gate to both contexts""" - - def __init__(self, embeddings_size, decoder_size, attention_size, output_size): - super(BothContextGate, self).__init__() - self.context_gate = ContextGate(embeddings_size, decoder_size, attention_size, output_size) - self.tanh = nn.Tanh() - - def forward(self, prev_emb, dec_state, attn_state): - z, source, target = self.context_gate(prev_emb, dec_state, attn_state) - return self.tanh((1.0 - z) * target + z * source) diff --git a/onmt/modules/global_attention.py b/onmt/modules/global_attention.py deleted file mode 100644 index da032c94..00000000 --- a/onmt/modules/global_attention.py +++ /dev/null @@ -1,225 +0,0 @@ -"""Global attention modules (Luong / Bahdanau)""" -import torch -import torch.nn as nn -import torch.nn.functional as F - -from onmt.modules.sparse_activations import sparsemax -from onmt.utils.misc import aeq, sequence_mask - -# This class is mainly used by decoder.py for RNNs but also -# by the CNN / transformer decoder when copy attention is used -# CNN has its own attention mechanism ConvMultiStepAttention -# Transformer has its own MultiHeadedAttention - - -class GlobalAttention(nn.Module): - r""" - Global attention takes a matrix and a query vector. It - then computes a parameterized convex combination of the matrix - based on the input query. - - Constructs a unit mapping a query `q` of size `dim` - and a source matrix `H` of size `n x dim`, to an output - of size `dim`. - - - .. mermaid:: - - graph BT - A[Query] - subgraph RNN - C[H 1] - D[H 2] - E[H N] - end - F[Attn] - G[Output] - A --> F - C --> F - D --> F - E --> F - C -.-> G - D -.-> G - E -.-> G - F --> G - - All models compute the output as - :math:`c = \sum_{j=1}^{\text{SeqLength}} a_j H_j` where - :math:`a_j` is the softmax of a score function. - Then then apply a projection layer to [q, c]. - - However they - differ on how they compute the attention score. - - * Luong Attention (dot, general): - * dot: :math:`\text{score}(H_j,q) = H_j^T q` - * general: :math:`\text{score}(H_j, q) = H_j^T W_a q` - - - * Bahdanau Attention (mlp): - * :math:`\text{score}(H_j, q) = v_a^T \text{tanh}(W_a q + U_a h_j)` - - - Args: - dim (int): dimensionality of query and key - coverage (bool): use coverage term - attn_type (str): type of attention to use, options [dot,general,mlp] - attn_func (str): attention function to use, options [softmax,sparsemax] - - """ - - def __init__(self, dim, coverage=False, attn_type="dot", attn_func="softmax"): - super(GlobalAttention, self).__init__() - - self.dim = dim - assert attn_type in ["dot", "general", "mlp"], "Please select a valid attention type (got {:s}).".format( - attn_type - ) - self.attn_type = attn_type - assert attn_func in ["softmax", "sparsemax"], "Please select a valid attention function." - self.attn_func = attn_func - - if self.attn_type == "general": - self.linear_in = nn.Linear(dim, dim, bias=False) - elif self.attn_type == "mlp": - self.linear_context = nn.Linear(dim, dim, bias=False) - self.linear_query = nn.Linear(dim, dim, bias=True) - self.v = nn.Linear(dim, 1, bias=False) - # mlp wants it with bias - out_bias = self.attn_type == "mlp" - self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias) - - if coverage: - self.linear_cover = nn.Linear(1, dim, bias=False) - - def score(self, h_t, h_s): - """ - Args: - h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)`` - h_s (FloatTensor): sequence of sources ``(batch, src_len, dim`` - - Returns: - FloatTensor: raw attention scores (unnormalized) for each src index - ``(batch, tgt_len, src_len)`` - """ - - # Check input sizes - src_batch, src_len, src_dim = h_s.size() - tgt_batch, tgt_len, tgt_dim = h_t.size() - aeq(src_batch, tgt_batch) - aeq(src_dim, tgt_dim) - aeq(self.dim, src_dim) - - if self.attn_type in ["general", "dot"]: - if self.attn_type == "general": - h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim) - h_t_ = self.linear_in(h_t_) - h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim) - h_s_ = h_s.transpose(1, 2) - # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) - return torch.bmm(h_t, h_s_) - else: - dim = self.dim - wq = self.linear_query(h_t.view(-1, dim)) - wq = wq.view(tgt_batch, tgt_len, 1, dim) - wq = wq.expand(tgt_batch, tgt_len, src_len, dim) - - uh = self.linear_context(h_s.contiguous().view(-1, dim)) - uh = uh.view(src_batch, 1, src_len, dim) - uh = uh.expand(src_batch, tgt_len, src_len, dim) - - # (batch, t_len, s_len, d) - wquh = torch.tanh(wq + uh) - - return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len) - - def forward(self, source, memory_bank, memory_lengths=None, coverage=None): - """ - - Args: - source (FloatTensor): query vectors ``(batch, tgt_len, dim)`` - memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)`` - memory_lengths (LongTensor): the source context lengths ``(batch,)`` - coverage (FloatTensor): None (not supported yet) - - Returns: - (FloatTensor, FloatTensor): - - * Computed vector ``(tgt_len, batch, dim)`` - * Attention distribtutions for each query - ``(tgt_len, batch, src_len)`` - """ - - # one step input - if source.dim() == 2: - one_step = True - source = source.unsqueeze(1) - else: - one_step = False - - batch, source_l, dim = memory_bank.size() - batch_, target_l, dim_ = source.size() - aeq(batch, batch_) - aeq(dim, dim_) - aeq(self.dim, dim) - if coverage is not None: - batch_, source_l_ = coverage.size() - aeq(batch, batch_) - aeq(source_l, source_l_) - - if coverage is not None: - cover = coverage.view(-1).unsqueeze(1) - memory_bank += self.linear_cover(cover).view_as(memory_bank) - memory_bank = torch.tanh(memory_bank) - - # compute attention scores, as in Luong et al. - align = self.score(source, memory_bank) - - if memory_lengths is not None: - mask = sequence_mask(memory_lengths, max_len=align.size(-1)) - mask = mask.unsqueeze(1) # Make it broadcastable. - align.masked_fill_(~mask, -float('inf')) - - # Softmax or sparsemax to normalize attention weights - if self.attn_func == "softmax": - align_vectors = F.softmax(align.view(batch * target_l, source_l), -1) - else: - align_vectors = sparsemax(align.view(batch * target_l, source_l), -1) - align_vectors = align_vectors.view(batch, target_l, source_l) - - # each context vector c_t is the weighted average - # over all the source hidden states - c = torch.bmm(align_vectors, memory_bank) - - # concatenate - concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2) - attn_h = self.linear_out(concat_c).view(batch, target_l, dim) - if self.attn_type in ["general", "dot"]: - attn_h = torch.tanh(attn_h) - - if one_step: - attn_h = attn_h.squeeze(1) - align_vectors = align_vectors.squeeze(1) - - # Check output sizes - batch_, dim_ = attn_h.size() - aeq(batch, batch_) - aeq(dim, dim_) - batch_, source_l_ = align_vectors.size() - aeq(batch, batch_) - aeq(source_l, source_l_) - - else: - attn_h = attn_h.transpose(0, 1).contiguous() - align_vectors = align_vectors.transpose(0, 1).contiguous() - # Check output sizes - target_l_, batch_, dim_ = attn_h.size() - aeq(target_l, target_l_) - aeq(batch, batch_) - aeq(dim, dim_) - target_l_, batch_, source_l_ = align_vectors.size() - aeq(target_l, target_l_) - aeq(batch, batch_) - aeq(source_l, source_l_) - - return attn_h, align_vectors diff --git a/onmt/modules/sparse_activations.py b/onmt/modules/sparse_activations.py deleted file mode 100644 index 7a5e8a75..00000000 --- a/onmt/modules/sparse_activations.py +++ /dev/null @@ -1,97 +0,0 @@ -""" -An implementation of sparsemax (Martins & Astudillo, 2016). See -:cite:`DBLP:journals/corr/MartinsA16` for detailed description. - -By Ben Peters and Vlad Niculae -""" - -import torch -from torch.autograd import Function -from torch.cuda.amp import custom_fwd, custom_bwd -import torch.nn as nn - - -def _make_ix_like(input, dim=0): - d = input.size(dim) - rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) - view = [1] * input.dim() - view[0] = -1 - return rho.view(view).transpose(0, dim) - - -def _threshold_and_support(input, dim=0): - """Sparsemax building block: compute the threshold - - Args: - input: any dimension - dim: dimension along which to apply the sparsemax - - Returns: - the threshold value - """ - - input_srt, _ = torch.sort(input, descending=True, dim=dim) - input_cumsum = input_srt.cumsum(dim) - 1 - rhos = _make_ix_like(input, dim) - support = rhos * input_srt > input_cumsum - - support_size = support.sum(dim=dim).unsqueeze(dim) - tau = input_cumsum.gather(dim, support_size - 1) - tau /= support_size.to(input.dtype) - return tau, support_size - - -class SparsemaxFunction(Function): - @staticmethod - @custom_fwd - def forward(ctx, input, dim=0): - """sparsemax: normalizing sparse transform (a la softmax) - - Parameters: - input (Tensor): any shape - dim: dimension along which to apply sparsemax - - Returns: - output (Tensor): same shape as input - """ - ctx.dim = dim - max_val, _ = input.max(dim=dim, keepdim=True) - input -= max_val # same numerical stability trick as for softmax - tau, supp_size = _threshold_and_support(input, dim=dim) - output = torch.clamp(input - tau, min=0) - ctx.save_for_backward(supp_size, output) - return output - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - supp_size, output = ctx.saved_tensors - dim = ctx.dim - grad_input = grad_output.clone() - grad_input[output == 0] = 0 - - v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() - v_hat = v_hat.unsqueeze(dim) - grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) - return grad_input, None - - -sparsemax = SparsemaxFunction.apply - - -class Sparsemax(nn.Module): - def __init__(self, dim=0): - self.dim = dim - super(Sparsemax, self).__init__() - - def forward(self, input): - return sparsemax(input, self.dim) - - -class LogSparsemax(nn.Module): - def __init__(self, dim=0): - self.dim = dim - super(LogSparsemax, self).__init__() - - def forward(self, input): - return torch.log(sparsemax(input, self.dim)) diff --git a/onmt/modules/sparse_losses.py b/onmt/modules/sparse_losses.py deleted file mode 100644 index 3e67c885..00000000 --- a/onmt/modules/sparse_losses.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import torch.nn as nn -from torch.autograd import Function -from torch.cuda.amp import custom_fwd, custom_bwd -from onmt.modules.sparse_activations import _threshold_and_support -from onmt.utils.misc import aeq - - -class SparsemaxLossFunction(Function): - @staticmethod - @custom_fwd - def forward(ctx, input, target): - """ - input (FloatTensor): ``(n, num_classes)``. - target (LongTensor): ``(n,)``, the indices of the target classes - """ - input_batch, classes = input.size() - target_batch = target.size(0) - aeq(input_batch, target_batch) - - z_k = input.gather(1, target.unsqueeze(1)).squeeze() - tau_z, support_size = _threshold_and_support(input, dim=1) - support = input > tau_z - x = torch.where(support, input**2 - tau_z**2, torch.tensor(0.0, device=input.device)).sum(dim=1) - ctx.save_for_backward(input, target, tau_z) - # clamping necessary because of numerical errors: loss should be lower - # bounded by zero, but negative values near zero are possible without - # the clamp - return torch.clamp(x / 2 - z_k + 0.5, min=0.0) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - input, target, tau_z = ctx.saved_tensors - sparsemax_out = torch.clamp(input - tau_z, min=0) - delta = torch.zeros_like(sparsemax_out) - delta.scatter_(1, target.unsqueeze(1), 1) - return sparsemax_out - delta, None - - -sparsemax_loss = SparsemaxLossFunction.apply - - -class SparsemaxLoss(nn.Module): - """ - An implementation of sparsemax loss, first proposed in - :cite:`DBLP:journals/corr/MartinsA16`. If using - a sparse output layer, it is not possible to use negative log likelihood - because the loss is infinite in the case the target is assigned zero - probability. Inputs to SparsemaxLoss are arbitrary dense real-valued - vectors (like in nn.CrossEntropyLoss), not probability vectors (like in - nn.NLLLoss). - """ - - def __init__(self, weight=None, ignore_index=-100, reduction='elementwise_mean'): - assert reduction in ['elementwise_mean', 'sum', 'none'] - self.reduction = reduction - self.weight = weight - self.ignore_index = ignore_index - super(SparsemaxLoss, self).__init__() - - def forward(self, input, target): - loss = sparsemax_loss(input, target) - if self.ignore_index >= 0: - ignored_positions = target == self.ignore_index - size = float((target.size(0) - ignored_positions.sum()).item()) - loss.masked_fill_(ignored_positions, 0.0) - else: - size = float(target.size(0)) - if self.reduction == 'sum': - loss = loss.sum() - elif self.reduction == 'elementwise_mean': - loss = loss.sum() / size - return loss diff --git a/onmt/modules/stable_embeddings.py b/onmt/modules/stable_embeddings.py deleted file mode 100644 index 8111cc76..00000000 --- a/onmt/modules/stable_embeddings.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -import torch - -from typing import Optional - -from torch import Tensor -import torch.nn.functional as F - -# from bitsandbytes.optim import GlobalOptimManager - - -class StableEmbedding(torch.nn.Embedding): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False, - _weight: Optional[Tensor] = None, - ) -> None: - super(StableEmbedding, self).__init__( - num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight - ) - self.norm = torch.nn.LayerNorm(embedding_dim, eps=1e-6) - # GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) - - def reset_parameters(self) -> None: - torch.nn.init.xavier_uniform_(self.weight) - self._fill_padding_idx_with_zero() - - ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding - to make the Layer compatible with Pytorch < 1.9. - This means that if this changes in future PyTorch releases this need to change too - which is cumbersome. However, with this we can ensure compatibility with previous - PyTorch releases. - ''' - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - emb = F.embedding( - input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse - ) - - return self.norm(emb) - - -class Embedding(torch.nn.Embedding): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False, - _weight: Optional[Tensor] = None, - ) -> None: - super(Embedding, self).__init__( - num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight - ) - # GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) - - def reset_parameters(self) -> None: - torch.nn.init.xavier_uniform_(self.weight) - self._fill_padding_idx_with_zero() - - ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding - to make the Layer compatible with Pytorch < 1.9. - This means that if this changes in future PyTorch releases this need to change too - which is cumbersome. However, with this we can ensure compatibility with previous - PyTorch releases. - ''' - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - emb = F.embedding( - input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse - ) - - return emb diff --git a/onmt/modules/structured_attention.py b/onmt/modules/structured_attention.py deleted file mode 100644 index 59206b61..00000000 --- a/onmt/modules/structured_attention.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch.nn as nn -import torch -import torch.cuda - - -class MatrixTree(nn.Module): - """Implementation of the matrix-tree theorem for computing marginals - of non-projective dependency parsing. This attention layer is used - in the paper "Learning Structured Text Representations" - :cite:`DBLP:journals/corr/LiuL17d`. - """ - - def __init__(self, eps=1e-5): - self.eps = eps - super(MatrixTree, self).__init__() - - def forward(self, input): - laplacian = input.exp() + self.eps - output = input.clone() - for b in range(input.size(0)): - lap = laplacian[b].masked_fill(torch.eye(input.size(1), device=input.device).ne(0), 0) - lap = -lap + torch.diag(lap.sum(0)) - # store roots on diagonal - lap[0] = input[b].diag().exp() - inv_laplacian = lap.inverse() - - factor = inv_laplacian.diag().unsqueeze(1).expand_as(input[b]).transpose(0, 1) - term1 = input[b].exp().mul(factor).clone() - term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone() - term1[:, 0] = 0 - term2[0] = 0 - output[b] = term1 - term2 - roots_output = input[b].diag().exp().mul(inv_laplacian.transpose(0, 1)[0]) - output[b] = output[b] + torch.diag(roots_output) - return output diff --git a/onmt/modules/weight_norm.py b/onmt/modules/weight_norm.py deleted file mode 100644 index 723a7d74..00000000 --- a/onmt/modules/weight_norm.py +++ /dev/null @@ -1,224 +0,0 @@ -""" Weights normalization modules """ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Parameter - - -def get_var_maybe_avg(namespace, var_name, training, polyak_decay): - """utility for retrieving polyak averaged params - Update average - """ - v = getattr(namespace, var_name) - v_avg = getattr(namespace, var_name + '_avg') - v_avg -= (1 - polyak_decay) * (v_avg - v.data) - - if training: - return v - else: - return v_avg - - -def get_vars_maybe_avg(namespace, var_names, training, polyak_decay): - """utility for retrieving polyak averaged params""" - vars = [] - for vn in var_names: - vars.append(get_var_maybe_avg(namespace, vn, training, polyak_decay)) - return vars - - -class WeightNormLinear(nn.Linear): - """ - Implementation of "Weight Normalization: A Simple Reparameterization - to Accelerate Training of Deep Neural Networks" - :cite:`DBLP:journals/corr/SalimansK16` - - As a reparameterization method, weight normalization is same - as BatchNormalization, but it doesn't depend on minibatch. - - NOTE: This is used nowhere in the code at this stage - Vincent Nguyen 05/18/2018 - """ - - def __init__(self, in_features, out_features, init_scale=1.0, polyak_decay=0.9995): - super(WeightNormLinear, self).__init__(in_features, out_features, bias=True) - - self.V = self.weight - self.g = Parameter(torch.Tensor(out_features)) - self.b = self.bias - - self.register_buffer('V_avg', torch.zeros(out_features, in_features)) - self.register_buffer('g_avg', torch.zeros(out_features)) - self.register_buffer('b_avg', torch.zeros(out_features)) - - self.init_scale = init_scale - self.polyak_decay = polyak_decay - self.reset_parameters() - - def reset_parameters(self): - return - - def forward(self, x, init=False): - if init is True: - # out_features * in_features - self.V.data.copy_(torch.randn(self.V.data.size()).type_as(self.V.data) * 0.05) - # norm is out_features * 1 - v_norm = self.V.data / self.V.data.norm(2, 1).expand_as(self.V.data) - # batch_size * out_features - x_init = F.linear(x, v_norm).data - # out_features - m_init, v_init = x_init.mean(0).squeeze(0), x_init.var(0).squeeze(0) - # out_features - scale_init = self.init_scale / torch.sqrt(v_init + 1e-10) - self.g.data.copy_(scale_init) - self.b.data.copy_(-m_init * scale_init) - x_init = scale_init.view(1, -1).expand_as(x_init) * (x_init - m_init.view(1, -1).expand_as(x_init)) - self.V_avg.copy_(self.V.data) - self.g_avg.copy_(self.g.data) - self.b_avg.copy_(self.b.data) - return x_init - else: - v, g, b = get_vars_maybe_avg(self, ['V', 'g', 'b'], self.training, polyak_decay=self.polyak_decay) - # batch_size * out_features - x = F.linear(x, v) - scalar = g / torch.norm(v, 2, 1).squeeze(1) - x = scalar.view(1, -1).expand_as(x) * x + b.view(1, -1).expand_as(x) - return x - - -class WeightNormConv2d(nn.Conv2d): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - init_scale=1.0, - polyak_decay=0.9995, - ): - super(WeightNormConv2d, self).__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, groups - ) - - self.V = self.weight - self.g = Parameter(torch.Tensor(out_channels)) - self.b = self.bias - - self.register_buffer('V_avg', torch.zeros(self.V.size())) - self.register_buffer('g_avg', torch.zeros(out_channels)) - self.register_buffer('b_avg', torch.zeros(out_channels)) - - self.init_scale = init_scale - self.polyak_decay = polyak_decay - self.reset_parameters() - - def reset_parameters(self): - return - - def forward(self, x, init=False): - if init is True: - # out_channels, in_channels // groups, * kernel_size - self.V.data.copy_(torch.randn(self.V.data.size()).type_as(self.V.data) * 0.05) - v_norm = self.V.data / self.V.data.view(self.out_channels, -1).norm(2, 1).view( - self.out_channels, *([1] * (len(self.kernel_size) + 1)) - ).expand_as(self.V.data) - x_init = F.conv2d(x, v_norm, None, self.stride, self.padding, self.dilation, self.groups).data - t_x_init = x_init.transpose(0, 1).contiguous().view(self.out_channels, -1) - m_init, v_init = t_x_init.mean(1).squeeze(1), t_x_init.var(1).squeeze(1) - # out_features - scale_init = self.init_scale / torch.sqrt(v_init + 1e-10) - self.g.data.copy_(scale_init) - self.b.data.copy_(-m_init * scale_init) - scale_init_shape = scale_init.view(1, self.out_channels, *([1] * (len(x_init.size()) - 2))) - m_init_shape = m_init.view(1, self.out_channels, *([1] * (len(x_init.size()) - 2))) - x_init = scale_init_shape.expand_as(x_init) * (x_init - m_init_shape.expand_as(x_init)) - self.V_avg.copy_(self.V.data) - self.g_avg.copy_(self.g.data) - self.b_avg.copy_(self.b.data) - return x_init - else: - v, g, b = get_vars_maybe_avg(self, ['V', 'g', 'b'], self.training, polyak_decay=self.polyak_decay) - - scalar = torch.norm(v.view(self.out_channels, -1), 2, 1) - if len(scalar.size()) == 2: - scalar = g / scalar.squeeze(1) - else: - scalar = g / scalar - - w = scalar.view(self.out_channels, *([1] * (len(v.size()) - 1))).expand_as(v) * v - - x = F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups) - return x - - -# This is used nowhere in the code at the moment (Vincent Nguyen 05/18/2018) - - -class WeightNormConvTranspose2d(nn.ConvTranspose2d): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - output_padding=0, - groups=1, - init_scale=1.0, - polyak_decay=0.9995, - ): - super(WeightNormConvTranspose2d, self).__init__( - in_channels, out_channels, kernel_size, stride, padding, output_padding, groups - ) - # in_channels, out_channels, *kernel_size - self.V = self.weight - self.g = Parameter(torch.Tensor(out_channels)) - self.b = self.bias - - self.register_buffer('V_avg', torch.zeros(self.V.size())) - self.register_buffer('g_avg', torch.zeros(out_channels)) - self.register_buffer('b_avg', torch.zeros(out_channels)) - - self.init_scale = init_scale - self.polyak_decay = polyak_decay - self.reset_parameters() - - def reset_parameters(self): - return - - def forward(self, x, init=False): - if init is True: - # in_channels, out_channels, *kernel_size - self.V.data.copy_(torch.randn(self.V.data.size()).type_as(self.V.data) * 0.05) - v_norm = self.V.data / self.V.data.transpose(0, 1).contiguous().view(self.out_channels, -1).norm(2, 1).view( - self.in_channels, self.out_channels, *([1] * len(self.kernel_size)) - ).expand_as(self.V.data) - x_init = F.conv_transpose2d( - x, v_norm, None, self.stride, self.padding, self.output_padding, self.groups - ).data - # self.out_channels, 1 - t_x_init = x_init.tranpose(0, 1).contiguous().view(self.out_channels, -1) - # out_features - m_init, v_init = t_x_init.mean(1).squeeze(1), t_x_init.var(1).squeeze(1) - # out_features - scale_init = self.init_scale / torch.sqrt(v_init + 1e-10) - self.g.data.copy_(scale_init) - self.b.data.copy_(-m_init * scale_init) - scale_init_shape = scale_init.view(1, self.out_channels, *([1] * (len(x_init.size()) - 2))) - m_init_shape = m_init.view(1, self.out_channels, *([1] * (len(x_init.size()) - 2))) - - x_init = scale_init_shape.expand_as(x_init) * (x_init - m_init_shape.expand_as(x_init)) - self.V_avg.copy_(self.V.data) - self.g_avg.copy_(self.g.data) - self.b_avg.copy_(self.b.data) - return x_init - else: - v, g, b = get_vars_maybe_avg(self, ['V', 'g', 'b'], self.training, polyak_decay=self.polyak_decay) - scalar = g / torch.norm(v.transpose(0, 1).contiguous().view(self.out_channels, -1), 2, 1).squeeze(1) - w = scalar.view(self.in_channels, self.out_channels, *([1] * (len(v.size()) - 2))).expand_as(v) * v - - x = F.conv_transpose2d(x, w, b, self.stride, self.padding, self.output_padding, self.groups) - return x diff --git a/onmt/tests/test_attention.py b/onmt/tests/test_attention.py deleted file mode 100644 index acffac3e..00000000 --- a/onmt/tests/test_attention.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Here come the tests for attention types and their compatibility -""" -import unittest -import torch -from torch.autograd import Variable - -import onmt - - -class TestAttention(unittest.TestCase): - def test_masked_global_attention(self): - - source_lengths = torch.IntTensor([7, 3, 5, 2]) - # illegal_weights_mask = torch.ByteTensor([ - # [0, 0, 0, 0, 0, 0, 0], - # [0, 0, 0, 1, 1, 1, 1], - # [0, 0, 0, 0, 0, 1, 1], - # [0, 0, 1, 1, 1, 1, 1]]) - - batch_size = source_lengths.size(0) - dim = 20 - - memory_bank = Variable(torch.randn(batch_size, source_lengths.max(), dim)) - hidden = Variable(torch.randn(batch_size, dim)) - - attn = onmt.modules.GlobalAttention(dim) - - _, alignments = attn(hidden, memory_bank, memory_lengths=source_lengths) - # TODO: fix for pytorch 0.3 - # illegal_weights = alignments.masked_select(illegal_weights_mask) - - # self.assertEqual(0.0, illegal_weights.data.sum()) diff --git a/onmt/tests/test_copy_generator.py b/onmt/tests/test_copy_generator.py deleted file mode 100644 index 4b3291fa..00000000 --- a/onmt/tests/test_copy_generator.py +++ /dev/null @@ -1,113 +0,0 @@ -import unittest -from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss - -import itertools -from copy import deepcopy - -import torch -from torch.nn.functional import softmax - -from onmt.tests.utils_for_tests import product_dict - - -class TestCopyGenerator(unittest.TestCase): - INIT_CASES = list( - product_dict( - input_size=[172], - output_size=[319], - pad_idx=[0, 39], - ) - ) - PARAMS = list(product_dict(batch_size=[1, 14], max_seq_len=[23], tgt_max_len=[50], n_extra_words=[107])) - - @classmethod - def dummy_inputs(cls, params, init_case): - hidden = torch.randn((params["batch_size"] * params["tgt_max_len"], init_case["input_size"])) - attn = torch.randn((params["batch_size"] * params["tgt_max_len"], params["max_seq_len"])) - src_map = torch.randn((params["max_seq_len"], params["batch_size"], params["n_extra_words"])) - return hidden, attn, src_map - - @classmethod - def expected_shape(cls, params, init_case): - return params["tgt_max_len"] * params["batch_size"], init_case["output_size"] + params["n_extra_words"] - - def test_copy_gen_forward_shape(self): - for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): - cgen = CopyGenerator(**init_case) - dummy_in = self.dummy_inputs(params, init_case) - res = cgen(*dummy_in) - expected_shape = self.expected_shape(params, init_case) - self.assertEqual(res.shape, expected_shape, init_case.__str__()) - - def test_copy_gen_outp_has_no_prob_of_pad(self): - for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): - cgen = CopyGenerator(**init_case) - dummy_in = self.dummy_inputs(params, init_case) - res = cgen(*dummy_in) - self.assertTrue(res[:, init_case["pad_idx"]].allclose(torch.tensor(0.0))) - - def test_copy_gen_trainable_params_update(self): - for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): - cgen = CopyGenerator(**init_case) - trainable_params = {n: p for n, p in cgen.named_parameters() if p.requires_grad} - assert len(trainable_params) > 0 # sanity check - old_weights = deepcopy(trainable_params) - dummy_in = self.dummy_inputs(params, init_case) - res = cgen(*dummy_in) - pretend_loss = res.sum() - pretend_loss.backward() - dummy_optim = torch.optim.SGD(trainable_params.values(), 1) - dummy_optim.step() - for param_name in old_weights.keys(): - self.assertTrue( - trainable_params[param_name].ne(old_weights[param_name]).any(), - param_name + " " + init_case.__str__(), - ) - - -class TestCopyGeneratorLoss(unittest.TestCase): - INIT_CASES = list( - product_dict(vocab_size=[172], unk_index=[0, 39], ignore_index=[1, 17], force_copy=[True, False]) # pad idx - ) - PARAMS = list(product_dict(batch_size=[1, 14], tgt_max_len=[50], n_extra_words=[107])) - - @classmethod - def dummy_inputs(cls, params, init_case): - n_unique_src_words = 13 - scores = torch.randn( - (params["batch_size"] * params["tgt_max_len"], init_case["vocab_size"] + n_unique_src_words) - ) - scores = softmax(scores, dim=1) - align = torch.randint(0, n_unique_src_words, (params["batch_size"] * params["tgt_max_len"],)) - target = torch.randint(0, init_case["vocab_size"], (params["batch_size"] * params["tgt_max_len"],)) - target[0] = init_case["unk_index"] - target[1] = init_case["ignore_index"] - return scores, align, target - - @classmethod - def expected_shape(cls, params, init_case): - return (params["batch_size"] * params["tgt_max_len"],) - - def test_copy_loss_forward_shape(self): - for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): - loss = CopyGeneratorLoss(**init_case) - dummy_in = self.dummy_inputs(params, init_case) - res = loss(*dummy_in) - expected_shape = self.expected_shape(params, init_case) - self.assertEqual(res.shape, expected_shape, init_case.__str__()) - - def test_copy_loss_ignore_index_is_ignored(self): - for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): - loss = CopyGeneratorLoss(**init_case) - scores, align, target = self.dummy_inputs(params, init_case) - res = loss(scores, align, target) - should_be_ignored = (target == init_case["ignore_index"]).nonzero(as_tuple=False) - assert len(should_be_ignored) > 0 # otherwise not testing anything - self.assertTrue(res[should_be_ignored].allclose(torch.tensor(0.0))) - - def test_copy_loss_output_range_is_positive(self): - for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): - loss = CopyGeneratorLoss(**init_case) - dummy_in = self.dummy_inputs(params, init_case) - res = loss(*dummy_in) - self.assertTrue((res >= 0).all()) diff --git a/onmt/tests/test_structured_attention.py b/onmt/tests/test_structured_attention.py deleted file mode 100644 index 543be5b8..00000000 --- a/onmt/tests/test_structured_attention.py +++ /dev/null @@ -1,12 +0,0 @@ -import unittest -from onmt.modules.structured_attention import MatrixTree - -import torch - - -class TestStructuredAttention(unittest.TestCase): - def test_matrix_tree_marg_pdfs_sum_to_1(self): - dtree = MatrixTree() - q = torch.rand(1, 5, 5) - marg = dtree.forward(q) - self.assertTrue(marg.sum(1).allclose(torch.tensor(1.0))) diff --git a/onmt/translate/__init__.py b/onmt/translate/__init__.py deleted file mode 100644 index 21901092..00000000 --- a/onmt/translate/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -""" Modules for translation """ -from onmt.translate.translator import Translator, GeneratorLM -from onmt.translate.translation import Translation, TranslationBuilder -from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer -from onmt.translate.beam_search import BeamSearchLM -from onmt.translate.decode_strategy import DecodeStrategy -from onmt.translate.greedy_search import GreedySearch, GreedySearchLM -from onmt.translate.penalties import PenaltyBuilder -from onmt.translate.translation_server import TranslationServer, ServerModelError - -__all__ = [ - 'Translator', - 'Translation', - 'BeamSearch', - 'GNMTGlobalScorer', - 'TranslationBuilder', - 'PenaltyBuilder', - 'TranslationServer', - 'ServerModelError', - "DecodeStrategy", - "GreedySearch", - "GreedySearchLM", - "BeamSearchLM", - "GeneratorLM", -] diff --git a/onmt/utils/__init__.py b/onmt/utils/__init__.py deleted file mode 100644 index e835836e..00000000 --- a/onmt/utils/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Module defining various utilities.""" -from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed -from onmt.utils.alignment import make_batch_align_matrix -from onmt.utils.report_manager import ReportMgr, build_report_manager -from onmt.utils.statistics import Statistics -from onmt.utils.optimizers import MultipleOptimizer, Optimizer, AdaFactorFairSeq -from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts - -__all__ = [ - "split_corpus", - "aeq", - "use_gpu", - "set_random_seed", - "ReportMgr", - "build_report_manager", - "Statistics", - "MultipleOptimizer", - "Optimizer", - "AdaFactorFairSeq", - "EarlyStopping", - "scorers_from_opts", - "make_batch_align_matrix", -] diff --git a/onmt/utils/cnn_factory.py b/onmt/utils/cnn_factory.py deleted file mode 100644 index 68430426..00000000 --- a/onmt/utils/cnn_factory.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Implementation of "Convolutional Sequence to Sequence Learning" -""" -import torch -import torch.nn as nn -import torch.nn.init as init - -import onmt.modules - -SCALE_WEIGHT = 0.5**0.5 - - -def shape_transform(x): - """Tranform the size of the tensors to fit for conv input.""" - return torch.unsqueeze(torch.transpose(x, 1, 2), 3) - - -class GatedConv(nn.Module): - """Gated convolution for CNN class""" - - def __init__(self, input_size, width=3, dropout=0.2, nopad=False): - super(GatedConv, self).__init__() - self.conv = onmt.modules.WeightNormConv2d( - input_size, 2 * input_size, kernel_size=(width, 1), stride=(1, 1), padding=(width // 2 * (1 - nopad), 0) - ) - init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout)) ** 0.5) - self.dropout = nn.Dropout(dropout) - - def forward(self, x_var): - x_var = self.dropout(x_var) - x_var = self.conv(x_var) - out, gate = x_var.split(int(x_var.size(1) / 2), 1) - out = out * torch.sigmoid(gate) - return out - - -class StackedCNN(nn.Module): - """Stacked CNN class""" - - def __init__(self, num_layers, input_size, cnn_kernel_width=3, dropout=0.2): - super(StackedCNN, self).__init__() - self.dropout = dropout - self.num_layers = num_layers - self.layers = nn.ModuleList() - for _ in range(num_layers): - self.layers.append(GatedConv(input_size, cnn_kernel_width, dropout)) - - def forward(self, x): - for conv in self.layers: - x = x + conv(x) - x *= SCALE_WEIGHT - return x diff --git a/onmt/utils/rnn_factory.py b/onmt/utils/rnn_factory.py deleted file mode 100644 index f35c48e3..00000000 --- a/onmt/utils/rnn_factory.py +++ /dev/null @@ -1,17 +0,0 @@ -""" - RNN tools -""" -import torch.nn as nn -import onmt.models - - -def rnn_factory(rnn_type, **kwargs): - """rnn factory, Use pytorch version when available.""" - no_pack_padded_seq = False - if rnn_type == "SRU": - # SRU doesn't support PackedSequence. - no_pack_padded_seq = True - rnn = onmt.models.sru.SRU(**kwargs) - else: - rnn = getattr(nn, rnn_type)(**kwargs) - return rnn, no_pack_padded_seq diff --git a/server.py b/server.py index 2e078ba6..43c54b1c 100644 --- a/server.py +++ b/server.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from onmt.bin.server import main +from mammoth.bin.server import main if __name__ == "__main__": diff --git a/setup.py b/setup.py index 555141bb..1dc4e9b0 100644 --- a/setup.py +++ b/setup.py @@ -7,11 +7,11 @@ long_description = f.read() setup( - name='OpenNMT-py', - description='A python implementation of OpenNMT', + name='mammoth', + description='Massively Multilingual Modular Open Translation @ Helsinki', long_description=long_description, long_description_content_type='text/markdown', - version='2.2.0', + version='0.1', packages=find_packages(), project_urls={ "Documentation": "http://opennmt.net/OpenNMT-py/", @@ -35,12 +35,12 @@ ], entry_points={ "console_scripts": [ - "onmt_server=onmt.bin.server:main", - "onmt_train=onmt.bin.train:main", - "onmt_translate=onmt.bin.translate:main", - "onmt_release_model=onmt.bin.release_model:main", - "onmt_average_models=onmt.bin.average_models:main", - "onmt_build_vocab=onmt.bin.build_vocab:main", + # "onmt_server=mammoth.bin.server:main", + "mammoth_train=mammoth.bin.train:main", + "mammoth_translate=mammoth.bin.translate:main", + # "onmt_release_model=mammoth.bin.release_model:main", + # "onmt_average_models=mammoth.bin.average_models:main", + # "onmt_build_vocab=mammoth.bin.build_vocab:main", ], }, ) diff --git a/test_communication/test.py b/test_communication/test.py index 4895fabd..932a1aed 100644 --- a/test_communication/test.py +++ b/test_communication/test.py @@ -7,9 +7,9 @@ import timeout_decorator -import onmt -from onmt.bin.train import train -from onmt.utils.parse import ArgumentParser +import mammoth +from mammoth.bin.train import train +from mammoth.utils.parse import ArgumentParser import torch.multiprocessing as mp @@ -20,7 +20,7 @@ class TestTraining(TestCase): @classmethod def setUpClass(cls) -> None: cls.parser = ArgumentParser(description="train.py") - onmt.opts.train_opts(cls.parser) + mammoth.opts.train_opts(cls.parser) # clear output folders for folder in ["models", "tensorboard"]: if os.path.exists(folder): @@ -35,13 +35,13 @@ def tearDown(self) -> None: child_process.kill() @staticmethod - def _get_model_components(opt) -> List[str]: + def _get_model_components(opts) -> List[str]: # N.B: These components are only valid for very vanilla language-specific xcoder with fully shared AB models - components_enc = [f"encoder_0_{src_lang}" for src_lang in ast.literal_eval(opt.src_vocab).keys()] - components_dec = [f"encoder_0_{tgt_lang}" for tgt_lang in ast.literal_eval(opt.tgt_vocab).keys()] - components_gen = [f"generator_{tgt_lang}" for tgt_lang in ast.literal_eval(opt.tgt_vocab).keys()] - components_src_emb = [f"src_embeddings_{src_lang}" for src_lang in ast.literal_eval(opt.src_vocab).keys()] - components_tgt_emb = [f"tgt_embeddings_{tgt_lang}" for tgt_lang in ast.literal_eval(opt.tgt_vocab).keys()] + components_enc = [f"encoder_0_{src_lang}" for src_lang in ast.literal_eval(opts.src_vocab).keys()] + components_dec = [f"encoder_0_{tgt_lang}" for tgt_lang in ast.literal_eval(opts.tgt_vocab).keys()] + components_gen = [f"generator_{tgt_lang}" for tgt_lang in ast.literal_eval(opts.tgt_vocab).keys()] + components_src_emb = [f"src_embeddings_{src_lang}" for src_lang in ast.literal_eval(opts.src_vocab).keys()] + components_tgt_emb = [f"tgt_embeddings_{tgt_lang}" for tgt_lang in ast.literal_eval(opts.tgt_vocab).keys()] return [ "frame", "attention_bridge", @@ -55,7 +55,7 @@ def _get_model_components(opt) -> List[str]: @timeout_decorator.timeout(60) def test_training_1gpu_4pairs(self): out_model_prefix = "wmt_1gpu_4pairs" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -72,14 +72,14 @@ def test_training_1gpu_4pairs(self): "0:0", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -91,7 +91,7 @@ def test_training_1gpu_4pairs(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_lin(self): out_model_prefix = "wmt_1gpu_4pairs_lin" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -114,14 +114,14 @@ def test_training_1gpu_4pairs_ab_lin(self): "0:0", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -133,7 +133,7 @@ def test_training_1gpu_4pairs_ab_lin(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_ff(self): out_model_prefix = "wmt_1gpu_4pairs_ff" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -154,14 +154,14 @@ def test_training_1gpu_4pairs_ab_ff(self): "0:0", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -173,7 +173,7 @@ def test_training_1gpu_4pairs_ab_ff(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_tf(self): out_model_prefix = "wmt_1gpu_4pairs_tf" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -194,14 +194,14 @@ def test_training_1gpu_4pairs_ab_tf(self): "0:0", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -213,7 +213,7 @@ def test_training_1gpu_4pairs_ab_tf(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_simple(self): out_model_prefix = "wmt_1gpu_4pairs_simple" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -234,14 +234,14 @@ def test_training_1gpu_4pairs_ab_simple(self): "0:0", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -253,7 +253,7 @@ def test_training_1gpu_4pairs_ab_simple(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_perceiver(self): out_model_prefix = "wmt_1gpu_4pairs_perceiver" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -274,14 +274,14 @@ def test_training_1gpu_4pairs_ab_perceiver(self): "0:0", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -293,7 +293,7 @@ def test_training_1gpu_4pairs_ab_perceiver(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs(self): out_model_prefix = "wmt_2gpus_4pairs" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -311,14 +311,14 @@ def test_training_2gpus_4pairs(self): "0:1", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -330,7 +330,7 @@ def test_training_2gpus_4pairs(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_lin(self): out_model_prefix = "wmt_2gpus_4pairs_lin" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -354,14 +354,14 @@ def test_training_2gpus_4pairs_ab_lin(self): "0:1", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -373,7 +373,7 @@ def test_training_2gpus_4pairs_ab_lin(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_ff(self): out_model_prefix = "wmt_2gpus_4pairs_ff" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -395,14 +395,14 @@ def test_training_2gpus_4pairs_ab_ff(self): "0:1", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -414,7 +414,7 @@ def test_training_2gpus_4pairs_ab_ff(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_tf(self): out_model_prefix = "wmt_2gpus_4pairs_tf" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -436,14 +436,14 @@ def test_training_2gpus_4pairs_ab_tf(self): "0:1", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -455,7 +455,7 @@ def test_training_2gpus_4pairs_ab_tf(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_simple(self): out_model_prefix = "wmt_2gpus_4pairs_simple" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -479,14 +479,14 @@ def test_training_2gpus_4pairs_ab_simple(self): "0:1", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -498,7 +498,7 @@ def test_training_2gpus_4pairs_ab_simple(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_perceiver(self): out_model_prefix = "wmt_2gpus_4pairs_perceiver" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -520,14 +520,14 @@ def test_training_2gpus_4pairs_ab_perceiver(self): "0:1", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -539,7 +539,7 @@ def test_training_2gpus_4pairs_ab_perceiver(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_crossed(self): out_model_prefix = "wmt_2gpus_4pairs_crossed" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -557,14 +557,14 @@ def test_training_2gpus_4pairs_crossed(self): "0:0", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -576,7 +576,7 @@ def test_training_2gpus_4pairs_crossed(self): @timeout_decorator.timeout(60) def test_training_4gpus_4pairs(self): out_model_prefix = "wmt_4gpus_4pairs" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -596,14 +596,14 @@ def test_training_4gpus_4pairs(self): "0:3", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -615,7 +615,7 @@ def test_training_4gpus_4pairs(self): @timeout_decorator.timeout(120) def test_training_3gpus_12pairs(self): out_model_prefix = "wmt_3gpus_12pairs" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_12pairs.yml", @@ -642,14 +642,14 @@ def test_training_3gpus_12pairs(self): "0:2", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -661,7 +661,7 @@ def test_training_3gpus_12pairs(self): @timeout_decorator.timeout(120) def test_training_3gpus_21pairs(self): out_model_prefix = "wmt_3gpus_21pairs" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_21pairs.yml", @@ -697,14 +697,14 @@ def test_training_3gpus_21pairs(self): "0:2", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -716,7 +716,7 @@ def test_training_3gpus_21pairs(self): @timeout_decorator.timeout(120) def test_training_4gpus_12pairs(self): out_model_prefix = "wmt_4gpus_12pairs" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_12pairs.yml", @@ -744,14 +744,14 @@ def test_training_4gpus_12pairs(self): "0:3", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -763,7 +763,7 @@ def test_training_4gpus_12pairs(self): @timeout_decorator.timeout(120) def test_training_4gpus_24pairs(self): out_model_prefix = "wmt_4gpus_24pairs" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_24pairs.yml", @@ -803,14 +803,14 @@ def test_training_4gpus_24pairs(self): "0:3", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -822,7 +822,7 @@ def test_training_4gpus_24pairs(self): @timeout_decorator.timeout(120) def test_training_1gpu_tensorboard(self): out_model_prefix = "wmt_1gpu_tb" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -843,14 +843,14 @@ def test_training_1gpu_tensorboard(self): "0:0", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -868,7 +868,7 @@ def test_training_1gpu_tensorboard(self): @timeout_decorator.timeout(120) def test_training_2gpus_tensorboard(self): out_model_prefix = "wmt_2gpus_tb" - opt, _ = self.parser.parse_known_args( + opts, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -890,14 +890,14 @@ def test_training_2gpus_tensorboard(self): "0:1", ] ) - components = self._get_model_components(opt) + components = self._get_model_components(opts) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opt) + train(opts) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -919,14 +919,14 @@ def test_training_2gpus_tensorboard(self): # @classmethod # def setUpClass(cls) -> None: # cls.parser = ArgumentParser(description="translate.py") -# onmt.opts.config_opts(cls.parser) -# onmt.opts.translate_opts(cls.parser) -# onmt.opts.build_bilingual_model(cls.parser) +# mammoth.opts.config_opts(cls.parser) +# mammoth.opts.translate_opts(cls.parser) +# mammoth.opts.build_bilingual_model(cls.parser) # # def test_translate(self): # # TODO: train model instead of loading one the one used now, # # remove all absolute paths, add test data in the repo -# opt, _ = self.parser.parse_known_args( +# opts, _ = self.parser.parse_known_args( # [ # "-gpu", # "0", @@ -945,4 +945,4 @@ def test_training_2gpus_tensorboard(self): # "-use_attention_bridge", # ] # ) -# translate(opt) +# translate(opts) diff --git a/tools/attention_bank.py b/tools/attention_bank.py index d244a9ff..38bdb866 100644 --- a/tools/attention_bank.py +++ b/tools/attention_bank.py @@ -7,12 +7,12 @@ import torch import tqdm -from onmt.inputters.dataset import ParallelCorpus -from onmt.inputters.dataloader import build_dataloader -from onmt.model_builder import load_test_multitask_model -from onmt.opts import build_bilingual_model, _add_dynamic_transform_opts -from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe -from onmt.utils.parse import ArgumentParser +from mammoth.inputters.dataset import ParallelCorpus +from mammoth.inputters.dataloader import build_dataloader +from mammoth.model_builder import load_test_multitask_model +from mammoth.opts import build_bilingual_model, _add_dynamic_transform_opts +from mammoth.transforms import get_transforms_cls, make_transforms, TransformPipe +from mammoth.utils.parse import ArgumentParser def get_opts(): @@ -78,7 +78,7 @@ def _extract(sentences_file, model, vocabs_dict, transforms, enc_id, batch_size= yield (memory_bank, src_lengths) -def extract(opts, vocabs_dict, model, model_opt, transforms): +def extract(opts, vocabs_dict, model, model_opts, transforms): """Compute representations drawn from the encoder and save them to file.""" sentence_reps = [] for src, src_length in _extract( @@ -95,7 +95,7 @@ def extract(opts, vocabs_dict, model, model_opt, transforms): torch.save(sentence_reps, opts.dump_file) -def estimate(opts, vocabs_dict, model, model_opt, transforms): +def estimate(opts, vocabs_dict, model, model_opts, transforms): """Estimate the matrix-variate distribution of representations drawn from the encoder.""" try: import sklearn.covariance @@ -134,7 +134,7 @@ def estimate(opts, vocabs_dict, model, model_opt, transforms): # return sampling_fn -def classify(opts, vocabs_dict, model, model_opt, transforms): +def classify(opts, vocabs_dict, model, model_opts, transforms): """Learn a simple SGD classifier using representations drawn from the encoder.""" try: import sklearn.linear_model @@ -224,7 +224,7 @@ def main(): # ArgumentParser.validate_translate_opts_dynamic(opts) opts.enc_id = opts.enc_id or opts.src_lang - vocabs_dict, model, model_opt = load_test_multitask_model(opts, opts.model) + vocabs_dict, model, model_opts = load_test_multitask_model(opts, opts.model) command_fn = { fn.__name__: fn for fn in [extract, estimate, classify] @@ -238,7 +238,7 @@ def main(): ] transform = TransformPipe.build_from(data_transform) - command_fn(opts, vocabs_dict, model.to(opts.device), model_opt, transform) + command_fn(opts, vocabs_dict, model.to(opts.device), model_opts, transform) if __name__ == '__main__': diff --git a/tools/average_models.py b/tools/average_models.py index 9e053a8c..ce714f92 100755 --- a/tools/average_models.py +++ b/tools/average_models.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from onmt.bin.average_models import main +from mammoth.bin.average_models import main if __name__ == "__main__": diff --git a/tools/embeddings_to_torch.py b/tools/embeddings_to_torch.py index 3bdb1fb2..00f2d981 100755 --- a/tools/embeddings_to_torch.py +++ b/tools/embeddings_to_torch.py @@ -4,7 +4,7 @@ import six import argparse import torch -from onmt.utils.logging import init_logger, logger +from mammoth.utils.logging import init_logger, logger # FIXME haven't touched that file yet... @@ -79,48 +79,48 @@ def main(): parser.add_argument('-verbose', action="store_true", default=False) parser.add_argument('-skip_lines', type=int, default=0, help="Skip first lines of the embedding file") parser.add_argument('-type', choices=["GloVe", "word2vec"], default="GloVe") - opt = parser.parse_args() + opts = parser.parse_args() - enc_vocab, dec_vocab = get_vocabs(opt.dict_file) + enc_vocab, dec_vocab = get_vocabs(opts.dict_file) # Read in embeddings - skip_lines = 1 if opt.type == "word2vec" else opt.skip_lines - if opt.emb_file_both is not None: - if opt.emb_file_enc is not None: + skip_lines = 1 if opts.type == "word2vec" else opts.skip_lines + if opts.emb_file_both is not None: + if opts.emb_file_enc is not None: raise ValueError("If --emb_file_both is passed in, you should not" "set --emb_file_enc.") - if opt.emb_file_dec is not None: + if opts.emb_file_dec is not None: raise ValueError("If --emb_file_both is passed in, you should not" "set --emb_file_dec.") set_of_src_and_tgt_vocab = set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys()) - logger.info("Reading encoder and decoder embeddings from {}".format(opt.emb_file_both)) - src_vectors, total_vec_count = read_embeddings(opt.emb_file_both, skip_lines, set_of_src_and_tgt_vocab) + logger.info("Reading encoder and decoder embeddings from {}".format(opts.emb_file_both)) + src_vectors, total_vec_count = read_embeddings(opts.emb_file_both, skip_lines, set_of_src_and_tgt_vocab) tgt_vectors = src_vectors logger.info("\tFound {} total vectors in file".format(total_vec_count)) else: - if opt.emb_file_enc is None: + if opts.emb_file_enc is None: raise ValueError( "If --emb_file_enc not provided. Please specify " "the file with encoder embeddings, or pass in " "--emb_file_both" ) - if opt.emb_file_dec is None: + if opts.emb_file_dec is None: raise ValueError( "If --emb_file_dec not provided. Please specify " "the file with encoder embeddings, or pass in " "--emb_file_both" ) - logger.info("Reading encoder embeddings from {}".format(opt.emb_file_enc)) - src_vectors, total_vec_count = read_embeddings(opt.emb_file_enc, skip_lines, filter_set=enc_vocab.stoi) + logger.info("Reading encoder embeddings from {}".format(opts.emb_file_enc)) + src_vectors, total_vec_count = read_embeddings(opts.emb_file_enc, skip_lines, filter_set=enc_vocab.stoi) logger.info("\tFound {} total vectors in file.".format(total_vec_count)) - logger.info("Reading decoder embeddings from {}".format(opt.emb_file_dec)) - tgt_vectors, total_vec_count = read_embeddings(opt.emb_file_dec, skip_lines, filter_set=dec_vocab.stoi) + logger.info("Reading decoder embeddings from {}".format(opts.emb_file_dec)) + tgt_vectors, total_vec_count = read_embeddings(opts.emb_file_dec, skip_lines, filter_set=dec_vocab.stoi) logger.info("\tFound {} total vectors in file".format(total_vec_count)) logger.info("After filtering to vectors in vocab:") logger.info("\t* enc: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(enc_vocab, src_vectors)) logger.info("\t* dec: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(dec_vocab, tgt_vectors)) # Write to file - enc_output_file = opt.output_file + ".enc.pt" - dec_output_file = opt.output_file + ".dec.pt" + enc_output_file = opts.output_file + ".enc.pt" + dec_output_file = opts.output_file + ".dec.pt" logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s" % (enc_output_file, dec_output_file)) torch.save(convert_to_torch_tensor(src_vectors, enc_vocab), enc_output_file) torch.save(convert_to_torch_tensor(tgt_vectors, dec_vocab), dec_output_file) diff --git a/tools/extract_embeddings.py b/tools/extract_embeddings.py index 41b2ded4..49ecbca9 100644 --- a/tools/extract_embeddings.py +++ b/tools/extract_embeddings.py @@ -2,14 +2,14 @@ import torch -import onmt -import onmt.model_builder +import mammoth +import mammoth.model_builder -from onmt.utils.parse import ArgumentParser -import onmt.opts +from mammoth.utils.parse import ArgumentParser +import mammoth.opts -from onmt.utils.misc import use_gpu -from onmt.utils.logging import init_logger, logger +from mammoth.utils.misc import use_gpu +from mammoth.utils.logging import init_logger, logger parser = argparse.ArgumentParser(description='translate.py') @@ -29,31 +29,31 @@ def write_embeddings(filename, dict, embeddings): def main(): dummy_parser = argparse.ArgumentParser(description='train.py') - onmt.opts.model_opts(dummy_parser) + mammoth.opts.model_opts(dummy_parser) dummy_opt = dummy_parser.parse_known_args([])[0] - opt = parser.parse_args() - opt.cuda = opt.gpu > -1 - if opt.cuda: - torch.cuda.set_device(opt.gpu) + opts = parser.parse_args() + opts.cuda = opts.gpu > -1 + if opts.cuda: + torch.cuda.set_device(opts.gpu) # Add in default model arguments, possibly added since training. - checkpoint = torch.load(opt.model, map_location=lambda storage, loc: storage) - model_opt = checkpoint['opt'] + checkpoint = torch.load(opts.model, map_location=lambda storage, loc: storage) + model_opts = checkpoint['opts'] fields = checkpoint['vocab'] src_dict = fields['src'].base_field.vocab # assumes src is text tgt_dict = fields['tgt'].base_field.vocab - model_opt = checkpoint['opt'] + model_opts = checkpoint['opts'] for arg in dummy_opt.__dict__: - if arg not in model_opt: - model_opt.__dict__[arg] = dummy_opt.__dict__[arg] + if arg not in model_opts: + model_opts.__dict__[arg] = dummy_opt.__dict__[arg] # build_base_model expects updated and validated opts - ArgumentParser.update_model_opts(model_opt) - ArgumentParser.validate_model_opts(model_opt) + ArgumentParser.update_model_opts(model_opts) + ArgumentParser.validate_model_opts(model_opts) - model = onmt.model_builder.build_base_model(model_opt, fields, use_gpu(opt), checkpoint) + model = mammoth.model_builder.build_base_model(model_opts, fields, use_gpu(opts), checkpoint) encoder = model.encoder # no encoder for LM task decoder = model.decoder @@ -61,10 +61,10 @@ def main(): decoder_embeddings = decoder.embeddings.word_lut.weight.data.tolist() logger.info("Writing source embeddings") - write_embeddings(opt.output_dir + "/src_embeddings.txt", src_dict, encoder_embeddings) + write_embeddings(opts.output_dir + "/src_embeddings.txt", src_dict, encoder_embeddings) logger.info("Writing target embeddings") - write_embeddings(opt.output_dir + "/tgt_embeddings.txt", tgt_dict, decoder_embeddings) + write_embeddings(opts.output_dir + "/tgt_embeddings.txt", tgt_dict, decoder_embeddings) logger.info('... done.') logger.info('Converting model...') diff --git a/tools/extract_vocabulary.py b/tools/extract_vocabulary.py index e003cc81..3c062e54 100644 --- a/tools/extract_vocabulary.py +++ b/tools/extract_vocabulary.py @@ -60,12 +60,12 @@ def main(): help="""Specifies 'src' or 'tgt' side for 'field' file_type.""", ) - opt = parser.parse_args() + opts = parser.parse_args() vocabulary = {} - if opt.file_type == 'text': + if opts.file_type == 'text': print("Reading input file...") - for batch in read_files_batch(opt.file): + for batch in read_files_batch(opts.file): for sentence in batch: for w in sentence: if w in vocabulary: @@ -74,19 +74,19 @@ def main(): vocabulary[w] = 1 print("Writing vocabulary file...") - with open(opt.out_file, "w") as f: + with open(opts.out_file, "w") as f: for w, count in sorted(vocabulary.items(), key=lambda x: x[1], reverse=True): f.write("{0}\n".format(w)) else: - if opt.side not in ['src', 'tgt']: + if opts.side not in ['src', 'tgt']: raise ValueError("If using -file_type='field', specifies 'src' or 'tgt' argument for -side.") import torch print("Reading input file...") - if not len(opt.file) == 1: + if not len(opts.file) == 1: raise ValueError("If using -file_type='field', only pass one argument for -file.") - vocabs = torch.load(opt.file[0]) - voc = dict(vocabs)[opt.side] + vocabs = torch.load(opts.file[0]) + voc = dict(vocabs)[opts.side] try: word_list = voc[0][1].base_field.vocab.itos @@ -94,7 +94,7 @@ def main(): word_list = voc[0][1].vocab.itos print("Writing vocabulary file...") - with open(opt.out_file, "wb") as f: + with open(opts.out_file, "wb") as f: for w in word_list: f.write(u"{0}\n".format(w).encode("utf-8")) diff --git a/tools/release_model.py b/tools/release_model.py index b716b115..dd437517 100644 --- a/tools/release_model.py +++ b/tools/release_model.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from onmt.bin.release_model import main +from mammoth.bin.release_model import main if __name__ == "__main__": diff --git a/tools/spm_to_vocab.py b/tools/spm_to_vocab.py index ba7d734d..f2371727 100644 --- a/tools/spm_to_vocab.py +++ b/tools/spm_to_vocab.py @@ -3,7 +3,7 @@ # counts) import sys import math -from onmt.constants import DefaultTokens +from mammoth.constants import DefaultTokens OMIT = (DefaultTokens.UNK, DefaultTokens.BOS, DefaultTokens.EOS) diff --git a/tools/test_rouge.py b/tools/test_rouge.py index 436edab9..12ccc35d 100644 --- a/tools/test_rouge.py +++ b/tools/test_rouge.py @@ -7,7 +7,7 @@ import sys import codecs -from onmt.utils.logging import init_logger, logger +from mammoth.utils.logging import init_logger, logger def eval_rouge(cand, ref): diff --git a/train.py b/train.py index 1b03c9bc..1648b083 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from onmt.bin.train import main +from mammoth.bin.train import main if __name__ == "__main__": diff --git a/translate.py b/translate.py index 5ca91336..c27cbfac 100644 --- a/translate.py +++ b/translate.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from onmt.bin.translate import main +from mammoth.bin.translate import main if __name__ == "__main__":