diff --git a/build_vocab.py b/build_vocab.py index fabea1b2..577c2c1c 100644 --- a/build_vocab.py +++ b/build_vocab.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from mammoth.bin.build_vocab import main +from onmt.bin.build_vocab import main if __name__ == "__main__": diff --git a/docs/source/CONTRIBUTING.md b/docs/source/CONTRIBUTING.md index 717a3ca0..7ad1425b 100644 --- a/docs/source/CONTRIBUTING.md +++ b/docs/source/CONTRIBUTING.md @@ -5,7 +5,7 @@ OpenNMT-py is a community developed project and we love developer contributions. ## Guidelines Before sending a PR, please do this checklist first: -- Please run `mammoth/tests/pull_request_chk.sh` and fix any errors. When adding new functionality, also add tests to this script. Included checks: +- Please run `onmt/tests/pull_request_chk.sh` and fix any errors. When adding new functionality, also add tests to this script. Included checks: 1. flake8 check for coding style; 2. unittest; 3. continuous integration tests listed in `.travis.yml`. diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md new file mode 100644 index 00000000..66a0b6b0 --- /dev/null +++ b/docs/source/FAQ.md @@ -0,0 +1,37 @@ +# Questions + +## What is the intuition behind fixed-length memory bank? +Specifically, for `lin` , the intuition behind the structured attention is to replace pooling over the hidden representations with multi-hop attentive representations (fixed length). What is the benefit for transforming source sequence representations into a fixed length memory bank? + +Push the model to be more language-agnostic. Sentence length tends to be language dependent. For example, French tends to produce longer sentences than English. + +Does the attention in attention bridge act as an enhancement of encoder? Will the attention bridge bring any benefits to decoders? +1. If we view attention bridge as a part of encoder, will the overall model be a partially shared encoder (separate lower layers and shared attention bridge) + separate decoders? + +If the shared attention is viewed is a part of encoder for many2one translation and a part of decoder for one2many translation, the shared attention module encoder some language-independent information to enhance encoding or decoding? + +## Models are saved with encoder, decoder, and generator. What is generator? +The generator contains Linear + activation (softmax or sparsesoftmax). + +### Why we need to separately save “generator”? +It seems unnecessary to separate the generator. Activation functions do not contain trainable parameters. + + +## What is the difference between `intermediate_output` and `encoder_output`? [🔗](./onmt/attention_bridge.py#L91) + +`intermediate_output` is the intermediate output of stacked n-layered attention bridges. `encoder_output` is literally the output of encoder, which was reused in the n-layered `PerceiverAttentionBridgeLayer`. + +For `PerceiverAttentionBridgeLayer` where the encoder output is projected into fixed length via `lattent_array`. But why? + +For `PerceiverAttentionBridgeLayer` : + +`intermediate_output` and `encoder_output` are used as: + +```python + S, B, F = encoder_output.shape + if intermediate_output is not None: + cross_query = intermediate_output + else: + cross_query = self.latent_array.unsqueeze(0).expand(B, -1, -1) + encoder_output = encoder_output.transpose(0, 1) +``` \ No newline at end of file diff --git a/docs/source/attention_bridges.md b/docs/source/attention_bridges.md index 3b014dbd..0080a85f 100644 --- a/docs/source/attention_bridges.md +++ b/docs/source/attention_bridges.md @@ -1,7 +1,7 @@ # Attention Bridge -The embeddings are generated through the self-attention mechanism ([Attention Bridge](./mammoth/modules/attention_bridge.py)) of the encoder and establish a connection with language-specific decoders that focus their attention on these embeddings. This is why they are referred to as 'bridges'. This architectural element serves to link the encoded information with the decoding process, enhancing the flow of information between different stages of language processing. +The embeddings are generated through the self-attention mechanism ([Attention Bridge](./onmt/attention_bridge.py)) of the encoder and establish a connection with language-specific decoders that focus their attention on these embeddings. This is why they are referred to as 'bridges'. This architectural element serves to link the encoded information with the decoding process, enhancing the flow of information between different stages of language processing. There are five types of attention mechanism implemented: @@ -61,7 +61,7 @@ The `PerceiverAttentionBridgeLayer` involves a multi-headed dot product self-att 3. **Linear Layer**: After normalization, the data is fed into a linear layer. This linear transformation can be seen as a learned projection of the attention-weighted data into a new space. -4. **ReLU Activation**: The output of the linear layer undergoes the Rectified Linear Unit (ReLU) activation function. +4. **ReLU Activation**: The output of the linear layer undergoes the Rectified Linear Unit (ReLU) activation function. 5. **Linear Layer (Second)**: Another linear layer is applied to the ReLU-activated output. @@ -72,11 +72,11 @@ The `PerceiverAttentionBridgeLayer` involves a multi-headed dot product self-att The process described involves dot product self-attention. The steps are as follows: 1. **Input Transformation**: Given an input matrix $\mathbf{H} \in \mathbb{R}^{d_h \times n}$, two sets of learned weight matrices are used to transform the input. These weight matrices are $\mathbf{W}_1 \in \mathbb{R}^{d_h \times d_a}$ and $\mathbf{W}_2 \in \mathbb{R}^{d_h \times d_a}$. The multiplication of $\mathbf{H}$ with $\mathbf{W}_1$ and $\mathbf{W}_2$ produces matrices $\mathbf{V}$ and $\mathbf{K}$, respectively: - + - $\mathbf{V} = \mathbf{H} \mathbf{W}_1$ - $\mathbf{K} = \mathbf{H} \mathbf{W}_2$ -2. **Attention Calculation**: The core attention calculation involves three matrices: $\mathbf{Q} \in \mathbb{R}^{d_h \times n}$, $\mathbf{K}$ (calculated previously), and $\mathbf{V}$ (calculated previously). The dot product of $\mathbf{Q}$ and $\mathbf{K}^\top$ is divided by the square root of the dimensionality of the input features ($\sqrt{d_h}$). +2. **Attention Calculation**: The core attention calculation involves three matrices: $\mathbf{Q} \in \mathbb{R}^{d_h \times n}$, $\mathbf{K}$ (calculated previously), and $\mathbf{V}$ (calculated previously). The dot product of $\mathbf{Q}$ and $\mathbf{K}^\top$ is divided by the square root of the dimensionality of the input features ($\sqrt{d_h}$). The final attended output is calculated by multiplying the attention weights with the $\mathbf{V}$ matrix: $\mathbf{H}^\prime = \operatorname{Softmax}(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_h}})\mathbf{V}$ @@ -86,4 +86,5 @@ The TransformerEncoderLayer employs multi-headed dot product self-attention (by ## FeedForwardAttentionBridgeLayer -The `FeedForwardAttentionBridgeLayer` module applies a sequence of linear transformations and `ReLU` activations to the input data, followed by an attention bridge normalization, enhancing the connectivity between different parts of the model. +The `FeedForwardAttentionBridgeLayer` module applies a sequence of linear transformations and `ReLU` activations to the input data, followed by an attention bridge normalization, enhancing the connectivity between different parts of the model. + diff --git a/docs/source/index.rst b/docs/source/index.rst index 7f04a9db..abb4c23b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,8 +38,8 @@ Contents :caption: API :maxdepth: 2 - mammoth.rst - mammoth.modules.rst - mammoth.translation.rst - mammoth.translate.translation_server.rst - mammoth.inputters.rst + onmt.rst + onmt.modules.rst + onmt.translation.rst + onmt.translate.translation_server.rst + onmt.inputters.rst diff --git a/docs/source/mammoth.inputters.rst b/docs/source/mammoth.inputters.rst deleted file mode 100644 index b95aae67..00000000 --- a/docs/source/mammoth.inputters.rst +++ /dev/null @@ -1,20 +0,0 @@ -Data Loaders -================= - -Data Readers -------------- - -.. autoexception:: mammoth.inputters.datareader_base.MissingDependencyException - -.. autoclass:: mammoth.inputters.DataReaderBase - :members: - -.. autoclass:: mammoth.inputters.TextDataReader - :members: - - -Dataset --------- - -.. autoclass:: mammoth.inputters.Dataset - :members: diff --git a/docs/source/mammoth.modules.rst b/docs/source/mammoth.modules.rst deleted file mode 100644 index de33bfd5..00000000 --- a/docs/source/mammoth.modules.rst +++ /dev/null @@ -1,109 +0,0 @@ -Modules -============= - -Core Modules ------------- - -.. autoclass:: mammoth.modules.Embeddings - :members: - - -Encoders ---------- - -.. autoclass:: mammoth.encoders.EncoderBase - :members: - -.. autoclass:: mammoth.encoders.MeanEncoder - :members: - -.. autoclass:: mammoth.encoders.RNNEncoder - :members: - - -Decoders ---------- - - -.. autoclass:: mammoth.decoders.DecoderBase - :members: - -.. autoclass:: mammoth.decoders.decoder.RNNDecoderBase - :members: - -.. autoclass:: mammoth.decoders.StdRNNDecoder - :members: - -.. autoclass:: mammoth.decoders.InputFeedRNNDecoder - :members: - -Attention ----------- - -.. autoclass:: mammoth.modules.AverageAttention - :members: - -.. autoclass:: mammoth.modules.GlobalAttention - :members: - - - -Architecture: Transformer ----------------------------- - -.. autoclass:: mammoth.modules.PositionalEncoding - :members: - -.. autoclass:: mammoth.modules.position_ffn.PositionwiseFeedForward - :members: - -.. autoclass:: mammoth.encoders.TransformerEncoder - :members: - -.. autoclass:: mammoth.decoders.TransformerDecoder - :members: - -.. autoclass:: mammoth.modules.MultiHeadedAttention - :members: - :undoc-members: - - -Architecture: Conv2Conv ----------------------------- - -(These methods are from a user contribution -and have not been thoroughly tested.) - - -.. autoclass:: mammoth.encoders.CNNEncoder - :members: - - -.. autoclass:: mammoth.decoders.CNNDecoder - :members: - -.. autoclass:: mammoth.modules.ConvMultiStepAttention - :members: - -.. autoclass:: mammoth.modules.WeightNormConv2d - :members: - -Architecture: SRU ----------------------------- - -.. autoclass:: mammoth.models.sru.SRU - :members: - - -Copy Attention --------------- - -.. autoclass:: mammoth.modules.CopyGenerator - :members: - - -Structured Attention -------------------------------------------- - -.. autoclass:: mammoth.modules.structured_attention.MatrixTree - :members: diff --git a/docs/source/mammoth.rst b/docs/source/mammoth.rst deleted file mode 100644 index cd3d2a8f..00000000 --- a/docs/source/mammoth.rst +++ /dev/null @@ -1,32 +0,0 @@ -Framework -================= - -Model ------ - -.. autoclass:: mammoth.models.NMTModel - :members: - -Trainer -------- - -.. autoclass:: mammoth.Trainer - :members: - - -.. autoclass:: mammoth.utils.Statistics - :members: - -Loss ----- - - -.. autoclass:: mammoth.utils.loss.LossComputeBase - :members: - - -Optimizer ---------- - -.. autoclass:: mammoth.utils.Optimizer - :members: diff --git a/docs/source/mammoth.translate.translation_server.rst b/docs/source/mammoth.translate.translation_server.rst deleted file mode 100644 index 0bc9dad7..00000000 --- a/docs/source/mammoth.translate.translation_server.rst +++ /dev/null @@ -1,21 +0,0 @@ -Server -====== - - -Models -------------- - -.. autoclass:: mammoth.translate.translation_server.ServerModel - :members: - - -Core Server ------------- - -.. autoexception:: mammoth.translate.translation_server.ServerModelError - -.. autoclass:: mammoth.translate.translation_server.Timer - :members: - -.. autoclass:: mammoth.translate.translation_server.TranslationServer - :members: diff --git a/docs/source/mammoth.translation.rst b/docs/source/mammoth.translation.rst deleted file mode 100644 index 6b075f96..00000000 --- a/docs/source/mammoth.translation.rst +++ /dev/null @@ -1,39 +0,0 @@ -Translation -================== - -Translations -------------- - -.. autoclass:: mammoth.translate.Translation - :members: - -Translator Class ------------------ - -.. autoclass:: mammoth.translate.Translator - :members: - -.. autoclass:: mammoth.translate.TranslationBuilder - :members: - - -Decoding Strategies --------------------- -.. autoclass:: mammoth.translate.DecodeStrategy - :members: - -.. autoclass:: mammoth.translate.BeamSearch - :members: - -.. autofunction:: mammoth.translate.greedy_search.sample_with_temperature - -.. autoclass:: mammoth.translate.GreedySearch - :members: - -Scoring --------- -.. autoclass:: mammoth.translate.penalties.PenaltyBuilder - :members: - -.. autoclass:: mammoth.translate.GNMTGlobalScorer - :members: diff --git a/docs/source/onmt.inputters.rst b/docs/source/onmt.inputters.rst new file mode 100644 index 00000000..99507e29 --- /dev/null +++ b/docs/source/onmt.inputters.rst @@ -0,0 +1,20 @@ +Data Loaders +================= + +Data Readers +------------- + +.. autoexception:: onmt.inputters.datareader_base.MissingDependencyException + +.. autoclass:: onmt.inputters.DataReaderBase + :members: + +.. autoclass:: onmt.inputters.TextDataReader + :members: + + +Dataset +-------- + +.. autoclass:: onmt.inputters.Dataset + :members: diff --git a/docs/source/onmt.modules.rst b/docs/source/onmt.modules.rst new file mode 100644 index 00000000..a3ef216e --- /dev/null +++ b/docs/source/onmt.modules.rst @@ -0,0 +1,109 @@ +Modules +============= + +Core Modules +------------ + +.. autoclass:: onmt.modules.Embeddings + :members: + + +Encoders +--------- + +.. autoclass:: onmt.encoders.EncoderBase + :members: + +.. autoclass:: onmt.encoders.MeanEncoder + :members: + +.. autoclass:: onmt.encoders.RNNEncoder + :members: + + +Decoders +--------- + + +.. autoclass:: onmt.decoders.DecoderBase + :members: + +.. autoclass:: onmt.decoders.decoder.RNNDecoderBase + :members: + +.. autoclass:: onmt.decoders.StdRNNDecoder + :members: + +.. autoclass:: onmt.decoders.InputFeedRNNDecoder + :members: + +Attention +---------- + +.. autoclass:: onmt.modules.AverageAttention + :members: + +.. autoclass:: onmt.modules.GlobalAttention + :members: + + + +Architecture: Transformer +---------------------------- + +.. autoclass:: onmt.modules.PositionalEncoding + :members: + +.. autoclass:: onmt.modules.position_ffn.PositionwiseFeedForward + :members: + +.. autoclass:: onmt.encoders.TransformerEncoder + :members: + +.. autoclass:: onmt.decoders.TransformerDecoder + :members: + +.. autoclass:: onmt.modules.MultiHeadedAttention + :members: + :undoc-members: + + +Architecture: Conv2Conv +---------------------------- + +(These methods are from a user contribution +and have not been thoroughly tested.) + + +.. autoclass:: onmt.encoders.CNNEncoder + :members: + + +.. autoclass:: onmt.decoders.CNNDecoder + :members: + +.. autoclass:: onmt.modules.ConvMultiStepAttention + :members: + +.. autoclass:: onmt.modules.WeightNormConv2d + :members: + +Architecture: SRU +---------------------------- + +.. autoclass:: onmt.models.sru.SRU + :members: + + +Copy Attention +-------------- + +.. autoclass:: onmt.modules.CopyGenerator + :members: + + +Structured Attention +------------------------------------------- + +.. autoclass:: onmt.modules.structured_attention.MatrixTree + :members: diff --git a/docs/source/onmt.rst b/docs/source/onmt.rst new file mode 100644 index 00000000..5ae056ce --- /dev/null +++ b/docs/source/onmt.rst @@ -0,0 +1,32 @@ +Framework +================= + +Model +----- + +.. autoclass:: onmt.models.NMTModel + :members: + +Trainer +------- + +.. autoclass:: onmt.Trainer + :members: + + +.. autoclass:: onmt.utils.Statistics + :members: + +Loss +---- + + +.. autoclass:: onmt.utils.loss.LossComputeBase + :members: + + +Optimizer +--------- + +.. autoclass:: onmt.utils.Optimizer + :members: diff --git a/docs/source/onmt.translate.translation_server.rst b/docs/source/onmt.translate.translation_server.rst new file mode 100644 index 00000000..3426fade --- /dev/null +++ b/docs/source/onmt.translate.translation_server.rst @@ -0,0 +1,21 @@ +Server +====== + + +Models +------------- + +.. autoclass:: onmt.translate.translation_server.ServerModel + :members: + + +Core Server +------------ + +.. autoexception:: onmt.translate.translation_server.ServerModelError + +.. autoclass:: onmt.translate.translation_server.Timer + :members: + +.. autoclass:: onmt.translate.translation_server.TranslationServer + :members: diff --git a/docs/source/onmt.translation.rst b/docs/source/onmt.translation.rst new file mode 100644 index 00000000..bb6f5a5d --- /dev/null +++ b/docs/source/onmt.translation.rst @@ -0,0 +1,39 @@ +Translation +================== + +Translations +------------- + +.. autoclass:: onmt.translate.Translation + :members: + +Translator Class +----------------- + +.. autoclass:: onmt.translate.Translator + :members: + +.. autoclass:: onmt.translate.TranslationBuilder + :members: + + +Decoding Strategies +-------------------- +.. autoclass:: onmt.translate.DecodeStrategy + :members: + +.. autoclass:: onmt.translate.BeamSearch + :members: + +.. autofunction:: onmt.translate.greedy_search.sample_with_temperature + +.. autoclass:: onmt.translate.GreedySearch + :members: + +Scoring +-------- +.. autoclass:: onmt.translate.penalties.PenaltyBuilder + :members: + +.. autoclass:: onmt.translate.GNMTGlobalScorer + :members: diff --git a/docs/source/options/build_vocab.rst b/docs/source/options/build_vocab.rst index 95bdc79b..57fda68e 100644 --- a/docs/source/options/build_vocab.rst +++ b/docs/source/options/build_vocab.rst @@ -2,7 +2,7 @@ Build Vocab =========== .. argparse:: - :filename: ../mammoth/bin/build_vocab.py + :filename: ../onmt/bin/build_vocab.py :func: _get_parser :prog: build_vocab.py diff --git a/docs/source/options/server.rst b/docs/source/options/server.rst index b883d4fe..63b2676f 100644 --- a/docs/source/options/server.rst +++ b/docs/source/options/server.rst @@ -2,6 +2,6 @@ Server ========= .. argparse:: - :filename: ../mammoth/bin/server.py + :filename: ../onmt/bin/server.py :func: _get_parser :prog: server.py \ No newline at end of file diff --git a/docs/source/options/train.rst b/docs/source/options/train.rst index 066aa160..67dc1cb2 100644 --- a/docs/source/options/train.rst +++ b/docs/source/options/train.rst @@ -2,6 +2,6 @@ Train ===== .. argparse:: - :filename: ../mammoth/bin/train.py + :filename: ../onmt/bin/train.py :func: _get_parser :prog: train.py \ No newline at end of file diff --git a/docs/source/options/translate.rst b/docs/source/options/translate.rst index 4b6244b7..db0423a4 100644 --- a/docs/source/options/translate.rst +++ b/docs/source/options/translate.rst @@ -2,6 +2,6 @@ Translate ========= .. argparse:: - :filename: ../mammoth/bin/translate.py + :filename: ../onmt/bin/translate.py :func: _get_parser :prog: translate.py \ No newline at end of file diff --git a/mammoth/__init__.py b/mammoth/__init__.py deleted file mode 100644 index fd6ae773..00000000 --- a/mammoth/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -""" Main entry point of the Mammoth library """ -import mammoth.inputters -import mammoth.models -import mammoth.utils -import mammoth.modules -import mammoth.opts -from mammoth.trainer import Trainer -import sys -import mammoth.utils.optimizers - -mammoth.utils.optimizers.Optim = mammoth.utils.optimizers.Optimizer -sys.modules["mammoth.Optim"] = mammoth.utils.optimizers - -__all__ = [ - mammoth.inputters, - mammoth.models, - mammoth.utils, - mammoth.modules, - mammoth.opts, - "Trainer" -] - -__version__ = "2.2.0" diff --git a/mammoth/bin/translate.py b/mammoth/bin/translate.py deleted file mode 100644 index 8c86bbf0..00000000 --- a/mammoth/bin/translate.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -from mammoth.utils.logging import init_logger -from mammoth.utils.misc import split_corpus -from mammoth.translate.translator import build_translator -# from mammoth.inputters.text_dataset import InferenceDataReader -from mammoth.transforms import get_transforms_cls, make_transforms, TransformPipe - -import mammoth.opts as opts -from mammoth.distributed import TaskSpecs -from mammoth.utils.parse import ArgumentParser - - -def translate(opts): - ArgumentParser.validate_translate_opts(opts) - ArgumentParser._get_all_transform_translate(opts) - ArgumentParser._validate_transforms_opts(opts) - ArgumentParser.validate_translate_opts_dynamic(opts) - logger = init_logger(opts.log_file) - - encoder_adapter_ids = set() - for layer_stack_idx, stack in enumerate(opts.stack['encoder']): - if 'adapters' in stack: - for group_id, sub_id in stack['adapters']: - encoder_adapter_ids.add((layer_stack_idx, group_id, sub_id)) - decoder_adapter_ids = set() - for layer_stack_idx, stack in enumerate(opts.stack['decoder']): - if 'adapters' in stack: - for group_id, sub_id in stack['adapters']: - decoder_adapter_ids.add((layer_stack_idx, group_id, sub_id)) - - logger.info( - 'It is ok that src_vocab and tgt_vocab are None here. ' - 'The vocabs are separately loaded in model_builder.' - ) - task = TaskSpecs( - node_rank=None, - local_rank=None, - src_lang=opts.src_lang, - tgt_lang=opts.tgt_lang, - encoder_id=[stack['id'] for stack in opts.stack['encoder']], - decoder_id=[stack['id'] for stack in opts.stack['decoder']], - corpus_id='trans', - weight=1, - corpus_opts=dict(), - src_vocab=None, - tgt_vocab=None, - encoder_adapter_ids=encoder_adapter_ids, - decoder_adapter_ids=decoder_adapter_ids, - ) - - translator = build_translator(opts, task, logger=logger, report_score=True) - - # data_reader = InferenceDataReader(opts.src, opts.tgt, opts.src_feats) - src_shards = split_corpus(opts.src, opts.shard_size) - tgt_shards = split_corpus(opts.tgt, opts.shard_size) - features_shards = [] - features_names = [] - for feat_name, feat_path in opts.src_feats.items(): - features_shards.append(split_corpus(feat_path, opts.shard_size)) - features_names.append(feat_name) - shard_pairs = zip(src_shards, tgt_shards, *features_shards) - - # Build transforms - transforms_cls = get_transforms_cls(opts._all_transform) - transforms = make_transforms(opts, transforms_cls, translator.vocabs, task=task) - data_transform = [ - transforms[name] for name in opts.transforms if name in transforms - ] - transform = TransformPipe.build_from(data_transform) - - for i, (src_shard, tgt_shard, *feats_shard) in enumerate(shard_pairs): - logger.info("Translating shard %d." % i) - translator.translate_dynamic( - src=src_shard, - transform=transform, - # src_feats=feats_shard, # TODO: put me back in - tgt=tgt_shard, - batch_size=opts.batch_size, - batch_type=opts.batch_type, - attn_debug=opts.attn_debug, - align_debug=opts.align_debug - ) - - -def _get_parser(): - parser = ArgumentParser(description='translate.py') - - opts.config_opts(parser) - opts.translate_opts(parser, dynamic=True) - opts.build_bilingual_model(parser) - return parser - - -def main(): - parser = _get_parser() - - opts = parser.parse_args() - translate(opts) - - -if __name__ == "__main__": - main() diff --git a/mammoth/distributed/__init__.py b/mammoth/distributed/__init__.py deleted file mode 100644 index 5a032c0c..00000000 --- a/mammoth/distributed/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Module defining distributed communications utilities.""" -from mammoth.distributed.communication import ( - all_gather_list, - batch_producer, - consumer, - broadcast_tensors, - only_ready_reduce_and_rescale_grads, - ErrorHandler, -) -from mammoth.distributed.contexts import DeviceContext, WorldContext, DeviceContextEnum -from mammoth.distributed.tasks import ( - TaskSpecs, - TaskQueueManager, - DatasetMetadata, - TASK_DISTRIBUTION_STRATEGIES, -) - -__all__ = [ - "all_gather_list", - "batch_producer", - "broadcast_tensors", - "consumer", - "only_ready_reduce_and_rescale_grads", - "ErrorHandler", - "DeviceContext", - "WorldContext", - "DeviceContextEnum", - "TASK_DISTRIBUTION_STRATEGIES", - "DatasetMetadata", - "TaskQueueManager", - "TaskSpecs", -] diff --git a/mammoth/distributed/communication.py b/mammoth/distributed/communication.py deleted file mode 100644 index 687da4b9..00000000 --- a/mammoth/distributed/communication.py +++ /dev/null @@ -1,282 +0,0 @@ -"""Module defining low-level comunication utilities (initialization, brodcasting, etc.)""" -import math -import os -import pickle -import signal - -import torch -import torch.distributed - -from mammoth.utils.logging import init_logger, logger -from mammoth.utils.misc import set_random_seed - - -def multi_init(opts, global_rank): - dist_init_method = 'tcp://{master_ip}:{master_port}'.format(master_ip=opts.master_ip, master_port=opts.master_port) - - dist_world_size = opts.world_size - torch.distributed.init_process_group( - backend=opts.gpu_backend, - init_method=dist_init_method, - rank=global_rank, - world_size=dist_world_size, - ) - - gpu_rank = torch.distributed.get_rank() - - return gpu_rank - - -def broadcast_tensors(tensors, src=0, group=None): - for t in tensors: - if group is None: - torch.distributed.broadcast(t, src) - else: - torch.distributed.broadcast(t, src, group=group) - - -def only_ready_reduce_and_rescale_grads(named_parameters, group=None): - """ - Gradient synch tolerant to missing grads. - - Missing grads occur when some parameters are not trained between two - gradient synchs, e.g. the embeddings of a low-resource language with low - sampling weight. - - The algorithm first uses the 'has_grad' attribute set by the forward hook - 'has_grad_hook'. This hook ensures that all parameters of the modules - selected for use during the current training computation have 'has_grad' - set to True. This gives the list of parameters that have been trained on - this device ("ready"). - - A bit mask covering the parameters that are ready on this device is - communicated to the other devices in the group. The bit masks are reduced - using summation. The sum gives the number of real gradients for that - parameter, and can be used for normalization. - - If a parameter is ready on any device, all devices communicate a value. - Devices on which the parameter is ready communicate the actual gradient, - while devices on which it is not ready communicate a dummy zero tensor - instead. The sum computed previously is used for normalization. - - Args: - named_parameters: tuples of (str, Parameter) defining the parameters to consider - group: torch.distributed communication group - """ - # Set missing gradients to zero, keeping track of true gradients - require_grad = [(name, p) for (name, p) in named_parameters if p.requires_grad] - if not require_grad: - # Exit early if the component has no parameters that require a gradient - return - device = require_grad[0][1].device - ready_list = [] - for name, p in require_grad: - if hasattr(p, 'has_grad') and p.has_grad: - ready_list.append(1.0) - else: - ready_list.append(0.0) - if p.grad is None: - p.grad = torch.zeros_like(p) - - # Communicate the ready bits, and reduce them using summation. - # This gives the number of non-dummy gradients participating, for normalization - ready_t = torch.tensor(ready_list).to(device) - if group is None: - torch.distributed.all_reduce(ready_t) - else: - torch.distributed.all_reduce(ready_t, group=group) - rescale_denoms = ready_t # after reduction - - # Omit if all nodes sent a zero ready bit - denoms_mask = (rescale_denoms > 0).cpu() - params_with_grad = [p for ((name, p), m) in zip(require_grad, denoms_mask) if m] - grads = [p.grad.data for p in params_with_grad] - rescale_denoms = [denom for (denom, m) in zip(rescale_denoms, denoms_mask) if m] - assert len(grads) == len(rescale_denoms) - if len(grads) == 0: - return - - # If not, then set has_grad also on devices that did not train the parameter themselves. - # They now have a grad that they received from the other devices. - for name, p in require_grad: - p.has_grad = True - - # All devices communicate either a real gradient or a dummy zeros of the same size - # Can not use rescale_denom, as each grad may have its own denominator - all_reduce_and_rescale_tensors(grads, rescale_denom=1, group=group) - - # Normalize using the previously computed values - for grad, denom in zip(grads, rescale_denoms): - if denom > 1: - grad.div_(denom) - # Note: p.has_grad is reused in the optimizer to prevent the untrained components from being stepped - - -def all_reduce_and_rescale_tensors(tensors, rescale_denom, group=None, buffer_size=10485760): - """ - All-reduce and rescale tensors in chunks of the specified size. - - Args: - tensors: list of Tensors to all-reduce - rescale_denom: denominator for rescaling summed Tensors - buffer_size: all-reduce chunk size in bytes - """ - # buffer size in bytes, determine equiv. # of elements based on data type - buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_() - buffer = [] - - def all_reduce_buffer(): - # copy tensors into buffer_t - offset = 0 - for t in buffer: - numel = t.numel() - buffer_t[offset:offset + numel].copy_(t.view(-1)) - offset += numel - - # all-reduce and rescale - if group is None: - torch.distributed.all_reduce(buffer_t[:offset]) - else: - torch.distributed.all_reduce(buffer_t[:offset], group=group) - buffer_t.div_(rescale_denom) - - # copy all-reduced buffer back into tensors - offset = 0 - for t in buffer: - numel = t.numel() - t.view(-1).copy_(buffer_t[offset:offset + numel]) - offset += numel - - filled = 0 - for t in tensors: - sz = t.numel() * t.element_size() - if sz > buffer_size: - # tensor is bigger than buffer, all-reduce and rescale directly - if group is None: - torch.distributed.all_reduce(t) - else: - torch.distributed.all_reduce(t, group=group) - t.div_(rescale_denom) - elif filled + sz > buffer_size: - # buffer is full, all-reduce and replace buffer with grad - all_reduce_buffer() - buffer = [t] - filled = sz - else: - # add tensor to buffer - buffer.append(t) - filled += sz - - if len(buffer) > 0: - all_reduce_buffer() - - -def all_gather_list(data, max_size=4096): - """Gathers arbitrary data from all nodes into a list.""" - world_size = torch.distributed.get_world_size() - if not hasattr(all_gather_list, '_in_buffer') or max_size != all_gather_list._in_buffer.size(): - all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) - all_gather_list._out_buffers = [torch.cuda.ByteTensor(max_size) for i in range(world_size)] - in_buffer = all_gather_list._in_buffer - out_buffers = all_gather_list._out_buffers - - enc = pickle.dumps(data) - enc_size = len(enc) - if enc_size + 2 > max_size: - raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2)) - assert max_size < 255 * 256 - in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k - in_buffer[1] = enc_size % 255 - in_buffer[2:enc_size + 2] = torch.ByteTensor(list(enc)) - - torch.distributed.all_gather(out_buffers, in_buffer.cuda()) - - results = [] - for i in range(world_size): - out_buffer = out_buffers[i] - size = (255 * out_buffer[0].item()) + out_buffer[1].item() - - bytes_list = bytes(out_buffer[2:size + 2].tolist()) - result = pickle.loads(bytes_list) - results.append(result) - return results - - -class ErrorHandler(object): - """A class that listens for exceptions in children processes and propagates - the tracebacks to the parent process.""" - - def __init__(self, error_queue): - """init error handler""" - import signal - import threading - - self.error_queue = error_queue - self.children_pids = [] - self.error_thread = threading.Thread(target=self.error_listener, daemon=True) - self.error_thread.start() - signal.signal(signal.SIGUSR1, self.signal_handler) - - def add_child(self, pid): - """error handler""" - self.children_pids.append(pid) - - def error_listener(self): - """error listener""" - (rank, original_trace) = self.error_queue.get() - self.error_queue.put((rank, original_trace)) - os.kill(os.getpid(), signal.SIGUSR1) - - def signal_handler(self, signalnum, stackframe): - """signal handler""" - for pid in self.children_pids: - os.kill(pid, signal.SIGINT) # kill children processes - (rank, original_trace) = self.error_queue.get() - msg = """\n\n-- Tracebacks above this line can probably - be ignored --\n\n""" - msg += original_trace - raise Exception(msg) - - -def batch_producer(generator_to_serve, queue, semaphore, opts, device_id): - """Produce batches to `queues` from `generator_to_serve`.""" - log_level = "INFO" if opts.verbose or device_id == 0 else "WARNING" - init_logger(opts.log_file, log_level=log_level) - set_random_seed(opts.seed, False) - logger.info("BATCH PRODUCER") - logger.info(generator_to_serve) - - for batch, metadata, communication_batch_id in generator_to_serve: - semaphore.acquire() - # Move batch to correspond device_id when consumer iterate - # hack to dodge unpicklable `dict_keys` - # batch.fields = list(batch.fields) - queue.put((batch, metadata, communication_batch_id)) - - -def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager): - """Run `process_fn` on `device_id` with data from `batch_queue`.""" - try: - logger.info( - f'global_rank {device_context.global_rank} ' - f'node_rank {device_context.node_rank} ' - f'local_rank {device_context.local_rank}' - ) - logger.info(f'opts.gpu_ranks {opts.gpu_ranks}') - multi_init(opts, device_context.global_rank) - # error_queue not passed (is this intentional?) - process_fn( - opts, - device_context=device_context, - batch_queue=batch_queue, - semaphore=semaphore, - task_queue_manager=task_queue_manager, - ) - - except KeyboardInterrupt: - pass # killed by parent, do nothing - except Exception: - # propagate exception to parent process, keeping original traceback - import traceback - - error_queue.put((opts.gpu_ranks[device_context.node_rank], traceback.format_exc())) diff --git a/mammoth/distributed/contexts.py b/mammoth/distributed/contexts.py deleted file mode 100644 index 8a8e4241..00000000 --- a/mammoth/distributed/contexts.py +++ /dev/null @@ -1,105 +0,0 @@ -from dataclasses import dataclass -from enum import Enum - - -class DeviceContextEnum(Enum): - CPU = 1 - SINGLE_GPU = 2 - MULTI_GPU = 3 - - -@dataclass -class WorldContext: - context: DeviceContextEnum - # Size of the world: total number of nodes, gpus on each node - n_nodes: int - gpus_per_node: int - - @property - def world_size(self): - """Total number of training GPUs""" - return self.n_nodes * self.gpus_per_node - - def is_distributed(self): - """When training is distributed over several devices, - multiprocessing is used to communicate gradients""" - return self.context == DeviceContextEnum.MULTI_GPU - - def is_gpu(self): - """Data tensors must be moved to the GPU for compute""" - return self.context != DeviceContextEnum.CPU - - def is_master(self): - """For code that should only run in one process: - - saving fully shared modules from one device only - - avoiding log spam when all devices would log the same result - """ - return not self.is_distributed() or self.global_rank == 0 - - def global_to_local(self, node_rank, local_rank): - assert node_rank is not None - assert local_rank is not None - return DeviceContext( - context=self.context, - n_nodes=self.n_nodes, - gpus_per_node=self.gpus_per_node, - node_rank=node_rank, - local_rank=local_rank, - ) - - @classmethod - def from_opts(cls, opts): - gpus_per_node = len(opts.gpu_ranks) - world_size = int(opts.world_size) if gpus_per_node > 0 else 0 - multinode = gpus_per_node != world_size - if world_size <= 0: - # setting a non-positive world size means use CPU - device_context_enum = DeviceContextEnum.CPU - if opts.n_nodes != 1: - raise ValueError('CPU training is only possible on a single node') - elif world_size == 1: - # world size 1 uses GPU, but is not distributed - device_context_enum = DeviceContextEnum.SINGLE_GPU - if opts.n_nodes != 1: - raise ValueError( - f'Invalid single-gpu node configuration: ' - f'n_nodes {opts.n_nodes} gpus_per_node {gpus_per_node} world_size {world_size}' - ) - else: - # world size > 1 - if multinode and opts.n_nodes == 1: - raise ValueError( - f'Invalid multi-node configuration: ' - f'n_nodes {opts.n_nodes} gpus_per_node {gpus_per_node} world_size {world_size}' - ) - device_context_enum = DeviceContextEnum.MULTI_GPU - world_context = WorldContext(context=device_context_enum, n_nodes=opts.n_nodes, gpus_per_node=gpus_per_node) - return world_context - - -@dataclass -class DeviceContext(WorldContext): - # Our place in the world - node_rank: int - local_rank: int - - @property - def global_rank(self) -> int: - return self.gpus_per_node * self.node_rank + self.local_rank - - @property - def id(self) -> str: - if self.is_gpu(): - return f'GPU {self.node_rank}:{self.local_rank}' - else: - return 'CPU' - - def validate(self, world_context): - # check that this DeviceContext is consistent with given WorldContext - assert self.context == world_context.context - assert self.n_nodes == world_context.n_nodes - assert self.gpus_per_node == world_context.gpus_per_node - # check that ranks are within the specified size of the world - assert 0 <= self.node_rank < self.n_nodes - if self.is_gpu(): - assert 0 <= self.local_rank < self.gpus_per_node diff --git a/mammoth/models/__init__.py b/mammoth/models/__init__.py deleted file mode 100644 index 30263cd6..00000000 --- a/mammoth/models/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Module defining models.""" -from mammoth.models.model_saver import build_model_saver, ModelSaver -from mammoth.models.model import NMTModel - -__all__ = ["build_model_saver", "ModelSaver", "NMTModel"] diff --git a/mammoth/modules/__init__.py b/mammoth/modules/__init__.py deleted file mode 100644 index 975b2ef6..00000000 --- a/mammoth/modules/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Components libary""" -from mammoth.modules.util_class import Elementwise -from mammoth.modules.multi_headed_attn import MultiHeadedAttention -from mammoth.modules.embeddings import Embeddings, PositionalEncoding -# from mammoth.modules.weight_norm import WeightNormConv2d -from mammoth.modules.average_attn import AverageAttention -from mammoth.modules.attention_bridge import AttentionBridge - -from mammoth.modules.encoder import EncoderBase -from mammoth.modules.transformer_encoder import TransformerEncoder -from mammoth.modules.mean_encoder import MeanEncoder - -from mammoth.modules.decoder import DecoderBase -from mammoth.modules.transformer_decoder import TransformerDecoder - - -str2enc = { - "transformer": TransformerEncoder, - "mean": MeanEncoder, -} - -str2dec = { - "transformer": TransformerDecoder, -} - - -__all__ = [ - "DecoderBase", - "TransformerDecoder", - "str2dec", - "EncoderBase", - "TransformerEncoder", - "MeanEncoder", - "str2enc", - "Elementwise", - "MultiHeadedAttention", - "Embeddings", - "PositionalEncoding", - "AverageAttention", - "AttentionBridge", -] diff --git a/mammoth/modules/decoder.py b/mammoth/modules/decoder.py deleted file mode 100644 index e0e707f5..00000000 --- a/mammoth/modules/decoder.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch.nn as nn - - -class DecoderBase(nn.Module): - """Abstract class for decoders. - - Args: - attentional (bool): The decoder returns non-empty attention. - """ - - def __init__(self, attentional=True): - super(DecoderBase, self).__init__() - self.attentional = attentional - - @classmethod - def from_opts(cls, opts, embeddings): - """Alternate constructor. - - Subclasses should override this method. - """ - - raise NotImplementedError diff --git a/mammoth/translate/__init__.py b/mammoth/translate/__init__.py deleted file mode 100644 index a48ea841..00000000 --- a/mammoth/translate/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -""" Modules for translation """ -from mammoth.translate.translator import Translator, GeneratorLM -from mammoth.translate.translation import Translation, TranslationBuilder -from mammoth.translate.beam_search import BeamSearch, GNMTGlobalScorer -from mammoth.translate.beam_search import BeamSearchLM -from mammoth.translate.decode_strategy import DecodeStrategy -from mammoth.translate.greedy_search import GreedySearch, GreedySearchLM -from mammoth.translate.penalties import PenaltyBuilder -from mammoth.translate.translation_server import TranslationServer, ServerModelError - -__all__ = [ - 'Translator', - 'Translation', - 'BeamSearch', - 'GNMTGlobalScorer', - 'TranslationBuilder', - 'PenaltyBuilder', - 'TranslationServer', - 'ServerModelError', - "DecodeStrategy", - "GreedySearch", - "GreedySearchLM", - "BeamSearchLM", - "GeneratorLM", -] diff --git a/mammoth/utils/__init__.py b/mammoth/utils/__init__.py deleted file mode 100644 index 49933156..00000000 --- a/mammoth/utils/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Module defining various utilities.""" -from mammoth.utils.misc import split_corpus, aeq, use_gpu, set_random_seed -from mammoth.utils.alignment import make_batch_align_matrix -from mammoth.utils.report_manager import ReportMgr, build_report_manager -from mammoth.utils.statistics import Statistics -from mammoth.utils.optimizers import MultipleOptimizer, Optimizer, AdaFactorFairSeq -from mammoth.utils.earlystopping import EarlyStopping, scorers_from_opts -from mammoth.utils.loss import build_loss_compute - -__all__ = [ - "split_corpus", - "aeq", - "use_gpu", - "set_random_seed", - "ReportMgr", - "build_report_manager", - "Statistics", - "MultipleOptimizer", - "Optimizer", - "AdaFactorFairSeq", - "EarlyStopping", - "scorers_from_opts", - "make_batch_align_matrix", - "build_loss_compute", -] diff --git a/onmt/__init__.py b/onmt/__init__.py new file mode 100644 index 00000000..78a71d74 --- /dev/null +++ b/onmt/__init__.py @@ -0,0 +1,27 @@ +""" Main entry point of the ONMT library """ +import onmt.inputters +import onmt.encoders +import onmt.decoders +import onmt.models +import onmt.utils +import onmt.modules +import onmt.opts +from onmt.trainer import Trainer +import sys +import onmt.utils.optimizers + +onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer +sys.modules["onmt.Optim"] = onmt.utils.optimizers + +__all__ = [ + onmt.inputters, + onmt.encoders, + onmt.decoders, + onmt.models, + onmt.utils, + onmt.modules, + onmt.opts, + "Trainer" +] + +__version__ = "2.2.0" diff --git a/mammoth/modules/attention_bridge.py b/onmt/attention_bridge.py similarity index 88% rename from mammoth/modules/attention_bridge.py rename to onmt/attention_bridge.py index d6fa0592..481c6f8e 100644 --- a/mammoth/modules/attention_bridge.py +++ b/onmt/attention_bridge.py @@ -4,10 +4,10 @@ import torch import torch.nn as nn -from mammoth.rmsnorm_torch import RMSNorm -from mammoth.modules.transformer_encoder import TransformerEncoderLayer +from onmt.rmsnorm_torch import RMSNorm +from onmt.encoders.transformer import TransformerEncoderLayer -from mammoth.modules.multi_headed_attn import MultiHeadedAttention +from onmt.modules.multi_headed_attn import MultiHeadedAttention class BaseAttentionBridgeLayer(nn.Module): @@ -73,15 +73,15 @@ def __init__( self.self_ff_norm = AttentionBridgeNorm(latent_size, norm_type) @classmethod - def from_opts(cls, opts): + def from_opt(cls, opt): return cls( - opts.model_dim, - opts.hidden_ab_size, - opts.ab_fixed_length, - opts.heads, - opts.attention_dropout[0], - opts.max_relative_positions, - opts.ab_layer_norm, + opt.rnn_size, + opt.hidden_ab_size, + opt.ab_fixed_length, + opt.heads, + opt.attention_dropout[0], + opt.max_relative_positions, + opt.ab_layer_norm, ) @property @@ -133,7 +133,7 @@ def __init__( attention_heads, hidden_ab_size, model_type, - model_dim, + dec_rnn_size, ab_layer_norm=None, ): """Attention Heads Layer:""" @@ -144,7 +144,7 @@ def __init__( self.dd = u self.model_type = model_type if self.model_type != "text": - d = model_dim + d = dec_rnn_size self.ws1 = nn.Linear(d, u, bias=True) self.ws2 = nn.Linear(u, r, bias=True) self.relu = nn.ReLU() @@ -154,15 +154,15 @@ def __init__( self.norm = AttentionBridgeNorm(d, ab_layer_norm) @classmethod - def from_opts(cls, opts): + def from_opt(cls, opt): """Alternate constructor.""" return cls( - opts.model_dim, - opts.ab_fixed_length, - opts.hidden_ab_size, - opts.model_type, - opts.model_dim, - opts.ab_layer_norm, + opt.rnn_size, + opt.ab_fixed_length, + opt.hidden_ab_size, + opt.model_type, + opt.dec_rnn_size, + opt.ab_layer_norm, ) def forward(self, intermediate_output, encoder_output, mask=None): @@ -244,12 +244,12 @@ def forward(self, intermediate_output, encoder_output, mask=None): return attention_weights, output @classmethod - def from_opts(cls, opts): + def from_opt(cls, opt): return cls( - opts.model_dim, - opts.hidden_ab_size, - opts.ab_fixed_length, - opts.ab_layer_norm, + opt.enc_rnn_size, + opt.hidden_ab_size, + opt.ab_fixed_length, + opt.ab_layer_norm, ) @@ -276,15 +276,15 @@ def forward(self, intermediate_output, encoder_output, mask=None): return None, outp @classmethod - def from_opts(cls, opts): + def from_opt(cls, opt): return cls( - opts.model_dim, - opts.heads, - opts.hidden_ab_size, # d_ff + opt.enc_rnn_size, + opt.heads, + opt.hidden_ab_size, # d_ff # TODO: that list indexing things seems suspicious to me... - opts.dropout[0], - opts.attention_dropout[0], - max_relative_positions=opts.max_relative_positions, + opt.dropout[0], + opt.attention_dropout[0], + max_relative_positions=opt.max_relative_positions, ) @@ -313,11 +313,11 @@ def forward(self, intermediate_output, encoder_output, mask=None): return None, self.module(intermediate_output) @classmethod - def from_opts(cls, opts): + def from_opt(cls, opt): return cls( - opts.model_dim, - opts.hidden_ab_size, - opts.ab_layer_norm, + opt.enc_rnn_size, + opt.hidden_ab_size, + opt.ab_layer_norm, ) @@ -333,7 +333,7 @@ def __init__(self, layers): self.is_fixed_length = any(x.is_fixed_length for x in layers) @classmethod - def from_opts(cls, opts): + def from_opt(cls, opt): """Alternate constructor.""" # convert opts specifications to architectures layer_type_to_cls = { @@ -344,16 +344,16 @@ def from_opts(cls, opts): 'feedforward': FeedForwardAttentionBridgeLayer, } - # preconstruct layers using .from_opts(...) - layers = [layer_type_to_cls[layer_type].from_opts(opts) for layer_type in opts.ab_layers] + # preconstruct layers using .from_opt(...) + layers = [layer_type_to_cls[layer_type].from_opt(opt) for layer_type in opt.ab_layers] # FIXME: locking-in edge case behavior - if any(layer == 'perceiver' for layer in opts.ab_layers): - first_perceiver_index = next(idx for idx, layer in enumerate(opts.ab_layers) if layer == 'perceiver') + if any(layer == 'perceiver' for layer in opt.ab_layers): + first_perceiver_index = next(idx for idx, layer in enumerate(opt.ab_layers) if layer == 'perceiver') if first_perceiver_index != 0: assert any(layer.is_fixed_length for layer in layers[:first_perceiver_index]), \ 'Unsupported bridge configuration: at least one layer must be fixed-size before perceiver' - if not all(layer == 'perceiver' for layer in opts.ab_layers): + if not all(layer == 'perceiver' for layer in opt.ab_layers): warnings.warn('Architecture-mixing not fully supported with perceiver.') # FIXME: deleting unused params manually for perceiver_layer in layers[1:]: diff --git a/mammoth/bin/__init__.py b/onmt/bin/__init__.py similarity index 100% rename from mammoth/bin/__init__.py rename to onmt/bin/__init__.py diff --git a/mammoth/bin/average_models.py b/onmt/bin/average_models.py similarity index 81% rename from mammoth/bin/average_models.py rename to onmt/bin/average_models.py index 417b2c6c..d9c09875 100755 --- a/mammoth/bin/average_models.py +++ b/onmt/bin/average_models.py @@ -5,7 +5,7 @@ def average_models(model_files, fp32=False): vocab = None - opts = None + opt = None avg_model = None avg_generator = None @@ -21,7 +21,7 @@ def average_models(model_files, fp32=False): generator_weights[k] = v.float() if i == 0: - vocab, opts = m['vocab'], m['opts'] + vocab, opt = m['vocab'], m['opt'] avg_model = model_weights avg_generator = generator_weights else: @@ -31,7 +31,7 @@ def average_models(model_files, fp32=False): for (k, v) in avg_generator.items(): avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1) - final = {"vocab": vocab, "opts": opts, "optim": None, "generator": avg_generator, "model": avg_model} + final = {"vocab": vocab, "opt": opt, "optim": None, "generator": avg_generator, "model": avg_model} return final @@ -40,10 +40,10 @@ def main(): parser.add_argument("-models", "-m", nargs="+", required=True, help="List of models") parser.add_argument("-output", "-o", required=True, help="Output file") parser.add_argument("-fp32", "-f", action="store_true", help="Cast params to float32") - opts = parser.parse_args() + opt = parser.parse_args() - final = average_models(opts.models, opts.fp32) - torch.save(final, opts.output) + final = average_models(opt.models, opt.fp32) + torch.save(final, opt.output) if __name__ == "__main__": diff --git a/mammoth/bin/build_vocab.py b/onmt/bin/build_vocab.py similarity index 89% rename from mammoth/bin/build_vocab.py rename to onmt/bin/build_vocab.py index 65f408b1..77ba3bf3 100644 --- a/mammoth/bin/build_vocab.py +++ b/onmt/bin/build_vocab.py @@ -1,13 +1,13 @@ #!/usr/bin/env python """Get vocabulary coutings from transformed corpora samples.""" from collections import Counter, defaultdict -from mammoth.utils.logging import init_logger -from mammoth.utils.misc import set_random_seed, check_path -from mammoth.utils.parse import ArgumentParser -from mammoth.opts import dynamic_prepare_opts -from mammoth.inputters import build_vocab_counts -from mammoth.transforms import make_transforms, get_transforms_cls -from mammoth.distributed import TaskSpecs +from onmt.utils.logging import init_logger +from onmt.utils.misc import set_random_seed, check_path +from onmt.utils.parse import ArgumentParser +from onmt.opts import dynamic_prepare_opts +from onmt.inputters import build_vocab_counts +from onmt.transforms import make_transforms, get_transforms_cls +from onmt.utils.distributed import TaskSpecs def build_vocab_main(opts): @@ -32,8 +32,8 @@ def build_vocab_main(opts): src_counters_by_lang = defaultdict(Counter) tgt_counters_by_lang = defaultdict(Counter) - for corpus_id in opts.tasks: - lang_pair = opts.tasks[corpus_id]['src_tgt'] + for corpus_id in opts.data: + lang_pair = opts.data[corpus_id]['src_tgt'] src_lang, tgt_lang = lang_pair.split('-') task = TaskSpecs( node_rank=None, @@ -44,7 +44,7 @@ def build_vocab_main(opts): decoder_id=tgt_lang, corpus_id=f'{src_lang}-{tgt_lang}', weight=1, - corpus_opts=dict(), + corpus_opt=dict(), src_vocab=None, tgt_vocab=None, encoder_adapter_ids=None, diff --git a/mammoth/bin/release_model.py b/onmt/bin/release_model.py similarity index 75% rename from mammoth/bin/release_model.py rename to onmt/bin/release_model.py index 354341da..7adcf93c 100755 --- a/mammoth/bin/release_model.py +++ b/onmt/bin/release_model.py @@ -17,19 +17,19 @@ def main(): default=None, help="Quantization type for CT2 model.", ) - opts = parser.parse_args() + opt = parser.parse_args() - model = torch.load(opts.model, map_location=torch.device("cpu")) - if opts.format == "pytorch": + model = torch.load(opt.model, map_location=torch.device("cpu")) + if opt.format == "pytorch": model["optim"] = None - torch.save(model, opts.output) - elif opts.format == "ctranslate2": + torch.save(model, opt.output) + elif opt.format == "ctranslate2": import ctranslate2 if not hasattr(ctranslate2, "__version__"): raise RuntimeError("onmt_release_model script requires ctranslate2 >= 2.0.0") - converter = ctranslate2.converters.OpenNMTPyConverter(opts.model) - converter.convert(opts.output, force=True, quantization=opts.quantization) + converter = ctranslate2.converters.OpenNMTPyConverter(opt.model) + converter.convert(opt.output, force=True, quantization=opt.quantization) if __name__ == "__main__": diff --git a/mammoth/bin/server.py b/onmt/bin/server.py similarity index 97% rename from mammoth/bin/server.py rename to onmt/bin/server.py index 7dae4712..75abbaf4 100755 --- a/mammoth/bin/server.py +++ b/onmt/bin/server.py @@ -3,7 +3,7 @@ from flask import Flask, jsonify, request from waitress import serve -from mammoth.translate import TranslationServer, ServerModelError +from onmt.translate import TranslationServer, ServerModelError import logging from logging.handlers import RotatingFileHandler @@ -50,9 +50,9 @@ def clone_model(model_id): timeout = data['timeout'] del data['timeout'] - opts = data.get('opts', None) + opt = data.get('opt', None) try: - model_id, load_time = translation_server.clone_model(model_id, opts, timeout) + model_id, load_time = translation_server.clone_model(model_id, opt, timeout) except ServerModelError as e: out['status'] = STATUS_ERROR out['error'] = str(e) diff --git a/mammoth/bin/train.py b/onmt/bin/train.py similarity index 62% rename from mammoth/bin/train.py rename to onmt/bin/train.py index 4c1d9072..2a376880 100644 --- a/mammoth/bin/train.py +++ b/onmt/bin/train.py @@ -4,7 +4,7 @@ from functools import partial import os -from mammoth.distributed import ( +from onmt.utils.distributed import ( DeviceContext, DeviceContextEnum, ErrorHandler, @@ -13,79 +13,79 @@ batch_producer, consumer, ) -from mammoth.utils.misc import set_random_seed -# from mammoth.modules.embeddings import prepare_pretrained_embeddings -from mammoth.utils.logging import init_logger, logger - -from mammoth.models.model_saver import load_checkpoint -from mammoth.train_single import main as single_main -from mammoth.inputters import DynamicDatasetIter - -from mammoth.utils.parse import ArgumentParser -from mammoth.opts import train_opts -from mammoth.inputters import get_vocab, DEFAULT_SPECIALS -from mammoth.transforms import get_transforms_cls +from onmt.utils.misc import set_random_seed +# from onmt.modules.embeddings import prepare_pretrained_embeddings +from onmt.utils.logging import init_logger, logger + +from onmt.models.model_saver import load_checkpoint +from onmt.train_single import main as single_main +from onmt.inputters import DynamicDatasetIter + +from onmt.utils.parse import ArgumentParser +from onmt.opts import train_opts +from onmt.inputters import get_vocab, DEFAULT_SPECIALS +from onmt.transforms import get_transforms_cls from collections import OrderedDict -from mammoth.constants import ModelTask +from onmt.constants import ModelTask # Set sharing strategy manually instead of default based on the OS. torch.multiprocessing.set_sharing_strategy('file_system') -# def prepare_fields_transforms(opts): +# def prepare_fields_transforms(opt): # """Prepare or dump fields & transforms before training.""" -# transforms_cls = get_transforms_cls(opts._all_transform) -# specials = get_specials(opts, transforms_cls) +# transforms_cls = get_transforms_cls(opt._all_transform) +# specials = get_specials(opt, transforms_cls) # -# fields = build_dynamic_fields(opts, src_specials=specials['src'], tgt_specials=specials['tgt']) +# fields = build_dynamic_fields(opt, src_specials=specials['src'], tgt_specials=specials['tgt']) # # # maybe prepare pretrained embeddings, if any -# prepare_pretrained_embeddings(opts, fields) +# prepare_pretrained_embeddings(opt, fields) # -# if opts.dump_fields: -# save_fields(fields, opts.save_data, overwrite=opts.overwrite) -# if opts.dump_transforms or opts.n_sample != 0: -# transforms = make_transforms(opts, transforms_cls, fields) -# if opts.dump_transforms: -# save_transforms(transforms, opts.save_data, overwrite=opts.overwrite) -# if opts.n_sample != 0: +# if opt.dump_fields: +# save_fields(fields, opt.save_data, overwrite=opt.overwrite) +# if opt.dump_transforms or opt.n_sample != 0: +# transforms = make_transforms(opt, transforms_cls, fields) +# if opt.dump_transforms: +# save_transforms(transforms, opt.save_data, overwrite=opt.overwrite) +# if opt.n_sample != 0: # logger.warning( -# f"`-n_sample` != 0: Training will not be started. Stop after saving {opts.n_sample} samples/corpus." +# f"`-n_sample` != 0: Training will not be started. Stop after saving {opt.n_sample} samples/corpus." # ) -# save_transformed_sample(opts, transforms, n_sample=opts.n_sample) +# save_transformed_sample(opt, transforms, n_sample=opt.n_sample) # logger.info("Sample saved, please check it before restart training.") # sys.exit() # return fields, transforms_cls # TODO: reimplement save_transformed_sample -def _init_train(opts): +def _init_train(opt): """Common initilization stuff for all training process.""" - ArgumentParser.validate_prepare_opts(opts) + ArgumentParser.validate_prepare_opts(opt) - if opts.train_from: + if opt.train_from: # Load checkpoint if we resume from a previous training. - checkpoint = load_checkpoint(ckpt_path=opts.train_from) - # fields = load_fields(opts.save_data, checkpoint) - transforms_cls = get_transforms_cls(opts._all_transform) + checkpoint = load_checkpoint(ckpt_path=opt.train_from) + # fields = load_fields(opt.save_data, checkpoint) + transforms_cls = get_transforms_cls(opt._all_transform) if ( - hasattr(checkpoint["opts"], '_all_transform') - and len(opts._all_transform.symmetric_difference(checkpoint["opts"]._all_transform)) != 0 + hasattr(checkpoint["opt"], '_all_transform') + and len(opt._all_transform.symmetric_difference(checkpoint["opt"]._all_transform)) != 0 ): _msg = "configured transforms is different from checkpoint:" - new_transf = opts._all_transform.difference(checkpoint["opts"]._all_transform) - old_transf = checkpoint["opts"]._all_transform.difference(opts._all_transform) + new_transf = opt._all_transform.difference(checkpoint["opt"]._all_transform) + old_transf = checkpoint["opt"]._all_transform.difference(opt._all_transform) if len(new_transf) != 0: _msg += f" +{new_transf}" if len(old_transf) != 0: _msg += f" -{old_transf}." logger.warning(_msg) - if opts.update_vocab: + if opt.update_vocab: logger.info("Updating checkpoint vocabulary with new vocabulary") - # fields, transforms_cls = prepare_fields_transforms(opts) + # fields, transforms_cls = prepare_fields_transforms(opt) else: checkpoint = None - # fields, transforms_cls = prepare_fields_transforms(opts) + # fields, transforms_cls = prepare_fields_transforms(opt) # Report src and tgt vocab sizes # for side in ['src', 'tgt']: @@ -100,24 +100,24 @@ def _init_train(opts): return checkpoint, None, transforms_cls -# def init_train_prepare_fields_transforms(opts, vocab_path, side): +# def init_train_prepare_fields_transforms(opt, vocab_path, side): # """Prepare or dump fields & transforms before training.""" # -# fields = None # build_dynamic_fields_langspec(opts, vocab_path, side) -# transforms_cls = get_transforms_cls(opts._all_transform) -# # TODO: maybe prepare pretrained embeddings, if any, with `prepare_pretrained_embeddings(opts, fields)` +# fields = None # build_dynamic_fields_langspec(opt, vocab_path, side) +# transforms_cls = get_transforms_cls(opt._all_transform) +# # TODO: maybe prepare pretrained embeddings, if any, with `prepare_pretrained_embeddings(opt, fields)` # -# # if opts.dump_fields: -# # save_fields(fields, opts.save_data, overwrite=opts.overwrite) -# if opts.dump_transforms or opts.n_sample != 0: -# transforms = make_transforms(opts, transforms_cls, fields) -# if opts.dump_transforms: -# save_transforms(transforms, opts.save_data, overwrite=opts.overwrite) -# if opts.n_sample != 0: +# # if opt.dump_fields: +# # save_fields(fields, opt.save_data, overwrite=opt.overwrite) +# if opt.dump_transforms or opt.n_sample != 0: +# transforms = make_transforms(opt, transforms_cls, fields) +# if opt.dump_transforms: +# save_transforms(transforms, opt.save_data, overwrite=opt.overwrite) +# if opt.n_sample != 0: # logger.warning( -# f"`-n_sample` != 0: Training will not be started. Stop after saving {opts.n_sample} samples/corpus." +# f"`-n_sample` != 0: Training will not be started. Stop after saving {opt.n_sample} samples/corpus." # ) -# save_transformed_sample(opts, transforms, n_sample=opts.n_sample) +# save_transformed_sample(opt, transforms, n_sample=opt.n_sample) # logger.info("Sample saved, please check it before restart training.") # sys.exit() # @@ -127,7 +127,7 @@ def _init_train(opts): # return fields -def validate_slurm_node_opts(current_env, world_context, opts): +def validate_slurm_node_opts(current_env, world_context, opt): """If you are using slurm, confirm that opts match slurm environment variables""" slurm_n_nodes = int(current_env['SLURM_NNODES']) if slurm_n_nodes != world_context.n_nodes: @@ -136,35 +136,35 @@ def validate_slurm_node_opts(current_env, world_context, opts): f'but set n_nodes to {world_context.n_nodes} in the conf' ) slurm_node_id = int(current_env['SLURM_NODEID']) - if slurm_node_id != opts.node_rank: + if slurm_node_id != opt.node_rank: raise ValueError( f'Looks like you are running on slurm node {slurm_node_id}, ' - f'but set node_rank to {opts.node_rank} on the command line' + f'but set node_rank to {opt.node_rank} on the command line' ) -def train(opts): - init_logger(opts.log_file) - ArgumentParser.validate_train_opts(opts) - ArgumentParser.update_model_opts(opts) - ArgumentParser.validate_model_opts(opts) - ArgumentParser.validate_prepare_opts(opts) - set_random_seed(opts.seed, False) +def train(opt): + init_logger(opt.log_file) + ArgumentParser.validate_train_opts(opt) + ArgumentParser.update_model_opts(opt) + ArgumentParser.validate_model_opts(opt) + ArgumentParser.validate_prepare_opts(opt) + set_random_seed(opt.seed, False) # set PyTorch distributed related environment variables current_env = os.environ - current_env["WORLD_SIZE"] = str(opts.world_size) - world_context = WorldContext.from_opts(opts) + current_env["WORLD_SIZE"] = str(opt.world_size) + world_context = WorldContext.from_opt(opt) if 'SLURM_NNODES' in current_env: - validate_slurm_node_opts(current_env, world_context, opts) + validate_slurm_node_opts(current_env, world_context, opt) logger.info(f'Training on {world_context}') - opts.data_task = ModelTask.SEQ2SEQ + opt.data_task = ModelTask.SEQ2SEQ - transforms_cls = get_transforms_cls(opts._all_transform) + transforms_cls = get_transforms_cls(opt._all_transform) if transforms_cls: logger.info(f'All transforms: {transforms_cls}') - src_specials, tgt_specials = zip(*(cls.get_specials(opts) for cls in transforms_cls.values())) + src_specials, tgt_specials = zip(*(cls.get_specials(opt) for cls in transforms_cls.values())) all_specials = set(DEFAULT_SPECIALS) for special_group in src_specials + tgt_specials: all_specials = all_specials | special_group @@ -177,12 +177,12 @@ def train(opts): vocabs_dict = OrderedDict() # For creating fields, we use a task_queue_manager that doesn't filter by node and gpu - global_task_queue_manager = TaskQueueManager.from_opts(opts, world_context) + global_task_queue_manager = TaskQueueManager.from_opt(opt, world_context) - vocab_size = {'src': opts.src_vocab_size or None, 'tgt': opts.tgt_vocab_size or None} + vocab_size = {'src': opt.src_vocab_size or None, 'tgt': opt.tgt_vocab_size or None} for side in ('src', 'tgt'): for lang in global_task_queue_manager.get_langs(side): - vocab_path = opts.__getattribute__(f'{side}_vocab')[lang] + vocab_path = opt.__getattribute__(f'{side}_vocab')[lang] # FIXME: for now, all specials are passed to all vocabs, this could be finer-grained vocabs_dict[(side, lang)] = get_vocab(vocab_path, lang, vocab_size[side], specials=all_specials) # for key, val in fields_dict: @@ -193,14 +193,14 @@ def train(opts): logger.debug(f"[{os.getpid()}] Initializing process group with: {current_env}") if world_context.context == DeviceContextEnum.MULTI_GPU: - current_env["MASTER_ADDR"] = opts.master_ip - current_env["MASTER_PORT"] = str(opts.master_port) - node_rank = opts.node_rank + current_env["MASTER_ADDR"] = opt.master_ip + current_env["MASTER_PORT"] = str(opt.master_port) + node_rank = opt.node_rank queues = [] semaphores = [] mp = torch.multiprocessing.get_context('spawn') - logger.info("world_size = {}, queue_size = {}".format(opts.world_size, opts.queue_size)) + logger.info("world_size = {}, queue_size = {}".format(opt.world_size, opt.queue_size)) # Create a thread to listen for errors in the child processes. error_queue = mp.SimpleQueue() error_handler = ErrorHandler(error_queue) @@ -217,21 +217,21 @@ def train(opts): task_queue_manager = global_task_queue_manager.global_to_local( node_rank=node_rank, local_rank=local_rank, - opts=opts + opt=opt ) # store rank in env (FIXME: is this obsolete?) current_env["RANK"] = str(device_context.global_rank) current_env["LOCAL_RANK"] = str(device_context.local_rank) - q = mp.Queue(opts.queue_size) - semaphore = mp.Semaphore(opts.queue_size) + q = mp.Queue(opt.queue_size) + semaphore = mp.Semaphore(opt.queue_size) queues.append(q) semaphores.append(semaphore) procs.append( mp.Process( target=consumer, - args=(train_process, opts, device_context, error_queue, q, semaphore, task_queue_manager), + args=(train_process, opt, device_context, error_queue, q, semaphore, task_queue_manager), daemon=True, ) ) @@ -244,12 +244,12 @@ def train(opts): task_queue_manager=task_queue_manager, transforms_cls=transforms_cls, vocabs_dict=vocabs_dict, - opts=opts, + opts=opt, is_train=True, ) producer = mp.Process( - target=batch_producer, args=(train_iter, q, semaphore, opts, local_rank), daemon=True + target=batch_producer, args=(train_iter, q, semaphore, opt, local_rank), daemon=True ) producers.append(producer) producers[local_rank].start() @@ -272,9 +272,9 @@ def train(opts): task_queue_manager = global_task_queue_manager.global_to_local( node_rank=0, local_rank=0, - opts=opts + opt=opt ) - train_process(opts, device_context=device_context, task_queue_manager=task_queue_manager) + train_process(opt, device_context=device_context, task_queue_manager=task_queue_manager) def _get_parser(): @@ -286,8 +286,8 @@ def _get_parser(): def main(): parser = _get_parser() - opts, unknown = parser.parse_known_args() - train(opts) + opt, unknown = parser.parse_known_args() + train(opt) if __name__ == "__main__": diff --git a/onmt/bin/translate.py b/onmt/bin/translate.py new file mode 100644 index 00000000..99ab0a82 --- /dev/null +++ b/onmt/bin/translate.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from onmt.utils.logging import init_logger +from onmt.utils.misc import split_corpus +from onmt.translate.translator import build_translator +# from onmt.inputters.text_dataset import InferenceDataReader +from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe + +import onmt.opts as opts +from onmt.utils.distributed import TaskSpecs +from onmt.utils.parse import ArgumentParser + + +def translate(opt): + ArgumentParser.validate_translate_opts(opt) + ArgumentParser._get_all_transform_translate(opt) + ArgumentParser._validate_transforms_opts(opt) + ArgumentParser.validate_translate_opts_dynamic(opt) + logger = init_logger(opt.log_file) + + encoder_adapter_ids = set() + for layer_stack_idx, stack in enumerate(opt.stack['encoder']): + if 'adapters' in stack: + for group_id, sub_id in stack['adapters']: + encoder_adapter_ids.add((layer_stack_idx, group_id, sub_id)) + decoder_adapter_ids = set() + for layer_stack_idx, stack in enumerate(opt.stack['decoder']): + if 'adapters' in stack: + for group_id, sub_id in stack['adapters']: + decoder_adapter_ids.add((layer_stack_idx, group_id, sub_id)) + + logger.info( + 'It is ok that src_vocab and tgt_vocab are None here. ' + 'The vocabs are separately loaded in model_builder.' + ) + task = TaskSpecs( + node_rank=None, + local_rank=None, + src_lang=opt.src_lang, + tgt_lang=opt.tgt_lang, + encoder_id=[stack['id'] for stack in opt.stack['encoder']], + decoder_id=[stack['id'] for stack in opt.stack['decoder']], + corpus_id='trans', + weight=1, + corpus_opt=dict(), + src_vocab=None, + tgt_vocab=None, + encoder_adapter_ids=encoder_adapter_ids, + decoder_adapter_ids=decoder_adapter_ids, + ) + + translator = build_translator(opt, task, logger=logger, report_score=True) + + # data_reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats) + src_shards = split_corpus(opt.src, opt.shard_size) + tgt_shards = split_corpus(opt.tgt, opt.shard_size) + features_shards = [] + features_names = [] + for feat_name, feat_path in opt.src_feats.items(): + features_shards.append(split_corpus(feat_path, opt.shard_size)) + features_names.append(feat_name) + shard_pairs = zip(src_shards, tgt_shards, *features_shards) + + # Build transforms + transforms_cls = get_transforms_cls(opt._all_transform) + transforms = make_transforms(opt, transforms_cls, translator.vocabs, task=task) + data_transform = [ + transforms[name] for name in opt.transforms if name in transforms + ] + transform = TransformPipe.build_from(data_transform) + + for i, (src_shard, tgt_shard, *feats_shard) in enumerate(shard_pairs): + logger.info("Translating shard %d." % i) + translator.translate_dynamic( + src=src_shard, + transform=transform, + # src_feats=feats_shard, # TODO: put me back in + tgt=tgt_shard, + batch_size=opt.batch_size, + batch_type=opt.batch_type, + attn_debug=opt.attn_debug, + align_debug=opt.align_debug + ) + + +def _get_parser(): + parser = ArgumentParser(description='translate.py') + + opts.config_opts(parser) + opts.translate_opts(parser, dynamic=True) + opts.build_bilingual_model(parser) + return parser + + +def main(): + parser = _get_parser() + + opt = parser.parse_args() + translate(opt) + + +if __name__ == "__main__": + main() diff --git a/mammoth/constants.py b/onmt/constants.py similarity index 100% rename from mammoth/constants.py rename to onmt/constants.py diff --git a/onmt/decoders/__init__.py b/onmt/decoders/__init__.py new file mode 100644 index 00000000..ab6262e3 --- /dev/null +++ b/onmt/decoders/__init__.py @@ -0,0 +1,21 @@ +"""Module defining decoders.""" +from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder +from onmt.decoders.transformer import TransformerDecoder +from onmt.decoders.cnn_decoder import CNNDecoder + + +str2dec = { + "rnn": StdRNNDecoder, + "ifrnn": InputFeedRNNDecoder, + "cnn": CNNDecoder, + "transformer": TransformerDecoder, +} + +__all__ = [ + "DecoderBase", + "TransformerDecoder", + "StdRNNDecoder", + "CNNDecoder", + "InputFeedRNNDecoder", + "str2dec", +] diff --git a/onmt/decoders/cnn_decoder.py b/onmt/decoders/cnn_decoder.py new file mode 100644 index 00000000..5a82f261 --- /dev/null +++ b/onmt/decoders/cnn_decoder.py @@ -0,0 +1,128 @@ +"""Implementation of the CNN Decoder part of +"Convolutional Sequence to Sequence Learning" +""" +import torch +import torch.nn as nn + +from onmt.modules import ConvMultiStepAttention, GlobalAttention +from onmt.utils.cnn_factory import shape_transform, GatedConv +from onmt.decoders.decoder import DecoderBase + +SCALE_WEIGHT = 0.5**0.5 + + +class CNNDecoder(DecoderBase): + """Decoder based on "Convolutional Sequence to Sequence Learning" + :cite:`DBLP:journals/corr/GehringAGYD17`. + + Consists of residual convolutional layers, with ConvMultiStepAttention. + """ + + def __init__( + self, num_layers, hidden_size, attn_type, copy_attn, cnn_kernel_width, dropout, embeddings, copy_attn_type + ): + super(CNNDecoder, self).__init__() + + self.cnn_kernel_width = cnn_kernel_width + self.embeddings = embeddings + + # Decoder State + self.state = {} + + input_size = self.embeddings.embedding_size + self.linear = nn.Linear(input_size, hidden_size) + self.conv_layers = nn.ModuleList( + [GatedConv(hidden_size, cnn_kernel_width, dropout, True) for i in range(num_layers)] + ) + self.attn_layers = nn.ModuleList([ConvMultiStepAttention(hidden_size) for i in range(num_layers)]) + + # CNNDecoder has its own attention mechanism. + # Set up a separate copy attention layer if needed. + assert not copy_attn, "Copy mechanism not yet tested in conv2conv" + if copy_attn: + self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type) + else: + self.copy_attn = None + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + opt.dec_layers, + opt.dec_rnn_size, + opt.global_attention, + opt.copy_attn, + opt.cnn_kernel_width, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + embeddings, + opt.copy_attn_type, + ) + + def init_state(self, _, memory_bank, enc_hidden): + """Init decoder state.""" + self.state["src"] = (memory_bank + enc_hidden) * SCALE_WEIGHT + self.state["previous_input"] = None + + def map_state(self, fn): + self.state["src"] = fn(self.state["src"], 1) + if self.state["previous_input"] is not None: + self.state["previous_input"] = fn(self.state["previous_input"], 1) + + def detach_state(self): + self.state["previous_input"] = self.state["previous_input"].detach() + + def forward(self, tgt, memory_bank, step=None, **kwargs): + """See :obj:`onmt.modules.RNNDecoderBase.forward()`""" + + if self.state["previous_input"] is not None: + tgt = torch.cat([self.state["previous_input"], tgt], 0) + + dec_outs = [] + attns = {"std": []} + if self.copy_attn is not None: + attns["copy"] = [] + + emb = self.embeddings(tgt) + assert emb.dim() == 3 # len x batch x embedding_dim + + tgt_emb = emb.transpose(0, 1).contiguous() + # The output of CNNEncoder. + src_memory_bank_t = memory_bank.transpose(0, 1).contiguous() + # The combination of output of CNNEncoder and source embeddings. + src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous() + + emb_reshape = tgt_emb.contiguous().view(tgt_emb.size(0) * tgt_emb.size(1), -1) + linear_out = self.linear(emb_reshape) + x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1) + x = shape_transform(x) + + pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1) + + pad = pad.type_as(x) + base_target_emb = x + + for conv, attention in zip(self.conv_layers, self.attn_layers): + new_target_input = torch.cat([pad, x], 2) + out = conv(new_target_input) + c, attn = attention(base_target_emb, out, src_memory_bank_t, src_memory_bank_c) + x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT + output = x.squeeze(3).transpose(1, 2) + + # Process the result and update the attentions. + dec_outs = output.transpose(0, 1).contiguous() + if self.state["previous_input"] is not None: + dec_outs = dec_outs[self.state["previous_input"].size(0):] + attn = attn[:, self.state["previous_input"].size(0):].squeeze() + attn = torch.stack([attn]) + attns["std"] = attn + if self.copy_attn is not None: + attns["copy"] = attn + + # Update the state. + self.state["previous_input"] = tgt + # TODO change the way attns is returned dict => list or tuple (onnx) + return dec_outs, attns + + def update_dropout(self, dropout): + for layer in self.conv_layers: + layer.dropout.p = dropout diff --git a/onmt/decoders/decoder.py b/onmt/decoders/decoder.py new file mode 100644 index 00000000..b5bdd516 --- /dev/null +++ b/onmt/decoders/decoder.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn + +from onmt.models.stacked_rnn import StackedLSTM, StackedGRU +from onmt.modules import context_gate_factory, GlobalAttention +from onmt.utils.rnn_factory import rnn_factory + +from onmt.utils.misc import aeq + + +class DecoderBase(nn.Module): + """Abstract class for decoders. + + Args: + attentional (bool): The decoder returns non-empty attention. + """ + + def __init__(self, attentional=True): + super(DecoderBase, self).__init__() + self.attentional = attentional + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor. + + Subclasses should override this method. + """ + + raise NotImplementedError + + +class RNNDecoderBase(DecoderBase): + """Base recurrent attention-based decoder class. + + Specifies the interface used by different decoder types + and required by :class:`~onmt.models.NMTModel`. + + + .. mermaid:: + + graph BT + A[Input] + subgraph RNN + C[Pos 1] + D[Pos 2] + E[Pos N] + end + G[Decoder State] + H[Decoder State] + I[Outputs] + F[memory_bank] + A--emb-->C + A--emb-->D + A--emb-->E + H-->C + C-- attn --- F + D-- attn --- F + E-- attn --- F + C-->I + D-->I + E-->I + E-->G + F---I + + Args: + rnn_type (str): + style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU] + bidirectional_encoder (bool) : use with a bidirectional encoder + num_layers (int) : number of stacked layers + hidden_size (int) : hidden size of each layer + attn_type (str) : see :class:`~onmt.modules.GlobalAttention` + attn_func (str) : see :class:`~onmt.modules.GlobalAttention` + coverage_attn (str): see :class:`~onmt.modules.GlobalAttention` + context_gate (str): see :class:`~onmt.modules.ContextGate` + copy_attn (bool): setup a separate copy attention mechanism + dropout (float) : dropout value for :class:`torch.nn.Dropout` + embeddings (onmt.modules.Embeddings): embedding module to use + reuse_copy_attn (bool): reuse the attention for copying + copy_attn_type (str): The copy attention style. See + :class:`~onmt.modules.GlobalAttention`. + """ + + def __init__( + self, + rnn_type, + bidirectional_encoder, + num_layers, + hidden_size, + attn_type="general", + attn_func="softmax", + coverage_attn=False, + context_gate=None, + copy_attn=False, + dropout=0.0, + embeddings=None, + reuse_copy_attn=False, + copy_attn_type="general", + ): + super(RNNDecoderBase, self).__init__(attentional=attn_type != "none" and attn_type is not None) + + self.bidirectional_encoder = bidirectional_encoder + self.num_layers = num_layers + self.hidden_size = hidden_size + self.embeddings = embeddings + self.dropout = nn.Dropout(dropout) + + # Decoder state + self.state = {} + + # Build the RNN. + self.rnn = self._build_rnn( + rnn_type, input_size=self._input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout + ) + + # Set up the context gate. + self.context_gate = None + if context_gate is not None: + self.context_gate = context_gate_factory( + context_gate, self._input_size, hidden_size, hidden_size, hidden_size + ) + + # Set up the standard attention. + self._coverage = coverage_attn + if not self.attentional: + if self._coverage: + raise ValueError("Cannot use coverage term with no attention.") + self.attn = None + else: + self.attn = GlobalAttention(hidden_size, coverage=coverage_attn, attn_type=attn_type, attn_func=attn_func) + + if copy_attn and not reuse_copy_attn: + if copy_attn_type == "none" or copy_attn_type is None: + raise ValueError("Cannot use copy_attn with copy_attn_type none") + self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type, attn_func=attn_func) + else: + self.copy_attn = None + + self._reuse_copy_attn = reuse_copy_attn and copy_attn + if self._reuse_copy_attn and not self.attentional: + raise ValueError("Cannot reuse copy attention with no attention.") + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + opt.rnn_type, + opt.brnn, + opt.dec_layers, + opt.dec_rnn_size, + opt.global_attention, + opt.global_attention_function, + opt.coverage_attn, + opt.context_gate, + opt.copy_attn, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + embeddings, + opt.reuse_copy_attn, + opt.copy_attn_type, + ) + + def init_state(self, src, memory_bank, encoder_final): + """Initialize decoder state with last state of the encoder.""" + + def _fix_enc_hidden(hidden): + # The encoder hidden is (layers*directions) x batch x dim. + # We need to convert it to layers x batch x (directions*dim). + if self.bidirectional_encoder: + hidden = torch.cat( + [hidden[0:hidden.size(0):2], hidden[1:hidden.size(0):2]], 2 + ) + return hidden + + if isinstance(encoder_final, tuple): # LSTM + self.state["hidden"] = tuple(_fix_enc_hidden(enc_hid) for enc_hid in encoder_final) + else: # GRU + self.state["hidden"] = (_fix_enc_hidden(encoder_final),) + + # Init the input feed. + batch_size = self.state["hidden"][0].size(1) + h_size = (batch_size, self.hidden_size) + self.state["input_feed"] = self.state["hidden"][0].data.new(*h_size).zero_().unsqueeze(0) + self.state["coverage"] = None + + def map_state(self, fn): + self.state["hidden"] = tuple(fn(h, 1) for h in self.state["hidden"]) + self.state["input_feed"] = fn(self.state["input_feed"], 1) + if self._coverage and self.state["coverage"] is not None: + self.state["coverage"] = fn(self.state["coverage"], 1) + + def detach_state(self): + self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"]) + self.state["input_feed"] = self.state["input_feed"].detach() + + def forward(self, tgt, memory_bank, memory_lengths=None, step=None, **kwargs): + """ + Args: + tgt (LongTensor): sequences of padded tokens + ``(tgt_len, batch, nfeats)``. + memory_bank (FloatTensor): vectors from the encoder + ``(src_len, batch, hidden)``. + memory_lengths (LongTensor): the padded source lengths + ``(batch,)``. + + Returns: + (FloatTensor, dict[str, FloatTensor]): + + * dec_outs: output from the decoder (after attn) + ``(tgt_len, batch, hidden)``. + * attns: distribution over src at each tgt + ``(tgt_len, batch, src_len)``. + """ + + dec_state, dec_outs, attns = self._run_forward_pass(tgt, memory_bank, memory_lengths=memory_lengths) + + # Update the state with the result. + if not isinstance(dec_state, tuple): + dec_state = (dec_state,) + self.state["hidden"] = dec_state + self.state["input_feed"] = dec_outs[-1].unsqueeze(0) + self.state["coverage"] = None + if "coverage" in attns: + self.state["coverage"] = attns["coverage"][-1].unsqueeze(0) + + # Concatenates sequence of tensors along a new dimension. + # NOTE: v0.3 to 0.4: dec_outs / attns[*] may not be list + # (in particular in case of SRU) it was not raising error in 0.3 + # since stack(Variable) was allowed. + # In 0.4, SRU returns a tensor that shouldn't be stacke + if type(dec_outs) == list: + dec_outs = torch.stack(dec_outs) + + for k in attns: + if type(attns[k]) == list: + attns[k] = torch.stack(attns[k]) + return dec_outs, attns + + def update_dropout(self, dropout): + self.dropout.p = dropout + self.embeddings.update_dropout(dropout) + + +class StdRNNDecoder(RNNDecoderBase): + """Standard fully batched RNN decoder with attention. + + Faster implementation, uses CuDNN for implementation. + See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options. + + + Based around the approach from + "Neural Machine Translation By Jointly Learning To Align and Translate" + :cite:`Bahdanau2015` + + + Implemented without input_feeding and currently with no `coverage_attn` + or `copy_attn` support. + """ + + def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None): + """ + Private helper for running the specific RNN forward pass. + Must be overriden by all subclasses. + + Args: + tgt (LongTensor): a sequence of input tokens tensors + ``(len, batch, nfeats)``. + memory_bank (FloatTensor): output(tensor sequence) from the + encoder RNN of size ``(src_len, batch, hidden_size)``. + memory_lengths (LongTensor): the source memory_bank lengths. + + Returns: + (Tensor, List[FloatTensor], Dict[str, List[FloatTensor]): + + * dec_state: final hidden state from the decoder. + * dec_outs: an array of output of every time + step from the decoder. + * attns: a dictionary of different + type of attention Tensor array of every time + step from the decoder. + """ + + assert self.copy_attn is None # TODO, no support yet. + assert not self._coverage # TODO, no support yet. + + attns = {} + emb = self.embeddings(tgt) + + if isinstance(self.rnn, nn.GRU): + rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0]) + else: + rnn_output, dec_state = self.rnn(emb, self.state["hidden"]) + + # Check + tgt_len, tgt_batch, _ = tgt.size() + output_len, output_batch, _ = rnn_output.size() + aeq(tgt_len, output_len) + aeq(tgt_batch, output_batch) + + # Calculate the attention. + if not self.attentional: + dec_outs = rnn_output + else: + dec_outs, p_attn = self.attn( + rnn_output.transpose(0, 1).contiguous(), memory_bank.transpose(0, 1), memory_lengths=memory_lengths + ) + attns["std"] = p_attn + + # Calculate the context gate. + if self.context_gate is not None: + dec_outs = self.context_gate( + emb.view(-1, emb.size(2)), rnn_output.view(-1, rnn_output.size(2)), dec_outs.view(-1, dec_outs.size(2)) + ) + dec_outs = dec_outs.view(tgt_len, tgt_batch, self.hidden_size) + + dec_outs = self.dropout(dec_outs) + return dec_state, dec_outs, attns + + def _build_rnn(self, rnn_type, **kwargs): + rnn, _ = rnn_factory(rnn_type, **kwargs) + return rnn + + @property + def _input_size(self): + return self.embeddings.embedding_size + + +class InputFeedRNNDecoder(RNNDecoderBase): + """Input feeding based decoder. + + See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options. + + Based around the input feeding approach from + "Effective Approaches to Attention-based Neural Machine Translation" + :cite:`Luong2015` + + + .. mermaid:: + + graph BT + A[Input n-1] + AB[Input n] + subgraph RNN + E[Pos n-1] + F[Pos n] + E --> F + end + G[Encoder] + H[memory_bank n-1] + A --> E + AB --> F + E --> H + G --> H + """ + + def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None): + """ + See StdRNNDecoder._run_forward_pass() for description + of arguments and return values. + """ + # Additional args check. + input_feed = self.state["input_feed"].squeeze(0) + input_feed_batch, _ = input_feed.size() + _, tgt_batch, _ = tgt.size() + aeq(tgt_batch, input_feed_batch) + # END Additional args check. + + dec_outs = [] + attns = {} + if self.attn is not None: + attns["std"] = [] + if self.copy_attn is not None or self._reuse_copy_attn: + attns["copy"] = [] + if self._coverage: + attns["coverage"] = [] + + emb = self.embeddings(tgt) + assert emb.dim() == 3 # len x batch x embedding_dim + + dec_state = self.state["hidden"] + coverage = self.state["coverage"].squeeze(0) if self.state["coverage"] is not None else None + + # Input feed concatenates hidden state with + # input at every time step. + for emb_t in emb.split(1): + decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1) + rnn_output, dec_state = self.rnn(decoder_input, dec_state) + if self.attentional: + decoder_output, p_attn = self.attn( + rnn_output, memory_bank.transpose(0, 1), memory_lengths=memory_lengths + ) + attns["std"].append(p_attn) + else: + decoder_output = rnn_output + if self.context_gate is not None: + # TODO: context gate should be employed + # instead of second RNN transform. + decoder_output = self.context_gate(decoder_input, rnn_output, decoder_output) + decoder_output = self.dropout(decoder_output) + input_feed = decoder_output + + dec_outs += [decoder_output] + + # Update the coverage attention. + if self._coverage: + coverage = p_attn if coverage is None else p_attn + coverage + attns["coverage"] += [coverage] + + if self.copy_attn is not None: + _, copy_attn = self.copy_attn(decoder_output, memory_bank.transpose(0, 1)) + attns["copy"] += [copy_attn] + elif self._reuse_copy_attn: + attns["copy"] = attns["std"] + + return dec_state, dec_outs, attns + + def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout): + assert rnn_type != "SRU", "SRU doesn't support input feed! Please set -input_feed 0!" + stacked_cell = StackedLSTM if rnn_type == "LSTM" else StackedGRU + return stacked_cell(num_layers, input_size, hidden_size, dropout) + + @property + def _input_size(self): + """Using input feed by concatenating input with attention vectors.""" + return self.embeddings.embedding_size + self.hidden_size + + def update_dropout(self, dropout): + self.dropout.p = dropout + self.rnn.dropout.p = dropout + self.embeddings.update_dropout(dropout) diff --git a/mammoth/modules/decoder_ensemble.py b/onmt/decoders/ensemble.py similarity index 89% rename from mammoth/modules/decoder_ensemble.py rename to onmt/decoders/ensemble.py index 08248f1d..c1b2c3b1 100644 --- a/mammoth/modules/decoder_ensemble.py +++ b/onmt/decoders/ensemble.py @@ -9,10 +9,10 @@ import torch import torch.nn as nn -from mammoth.modules.encoder import EncoderBase -from mammoth.modules.decoder import DecoderBase -from mammoth.models import NMTModel -import mammoth.model_builder +from onmt.encoders.encoder import EncoderBase +from onmt.decoders.decoder import DecoderBase +from onmt.models import NMTModel +import onmt.model_builder class EnsembleDecoderOutput(object): @@ -23,7 +23,7 @@ def __init__(self, model_dec_outs): def squeeze(self, dim=None): """Delegate squeeze to avoid modifying - :func:`mammoth.translate.translator.Translator.translate_batch()` + :func:`onmt.translate.translator.Translator.translate_batch()` """ return EnsembleDecoderOutput([x.squeeze(dim) for x in self.model_dec_outs]) @@ -53,7 +53,7 @@ def __init__(self, model_decoders): self.model_decoders = model_decoders def forward(self, tgt, memory_bank, memory_lengths=None, step=None, **kwargs): - """See :func:`mammoth.decoders.decoder.DecoderBase.forward()`.""" + """See :func:`onmt.decoders.decoder.DecoderBase.forward()`.""" # Memory_lengths is a single tensor shared between all models. # This assumption will not hold if Translator is modified # to calculate memory_lengths as something other than the length @@ -120,13 +120,13 @@ def __init__(self, models, raw_probs=False): self.models = nn.ModuleList(models) -def load_test_model(opts): +def load_test_model(opt): """Read in multiple models for ensemble.""" shared_vocabs = None shared_model_opt = None models = [] - for model_path in opts.models: - vocabs, model, model_opts = mammoth.model_builder.load_test_multitask_model(opts, model_path=model_path) + for model_path in opt.models: + vocabs, model, model_opt = onmt.model_builder.load_test_multitask_model(opt, model_path=model_path) if shared_vocabs is None: shared_vocabs = vocabs else: @@ -137,6 +137,6 @@ def load_test_model(opts): # assert vocab.stoi == sh_vocab.stoi, "Ensemble models must use the same preprocessed data" models.append(model) if shared_model_opt is None: - shared_model_opt = model_opts - ensemble_model = EnsembleModel(models, opts.avg_raw_probs) + shared_model_opt = model_opt + ensemble_model = EnsembleModel(models, opt.avg_raw_probs) return shared_vocabs, ensemble_model, shared_model_opt diff --git a/mammoth/modules/layer_stack_decoder.py b/onmt/decoders/layer_stack_decoder.py similarity index 75% rename from mammoth/modules/layer_stack_decoder.py rename to onmt/decoders/layer_stack_decoder.py index a2136889..5fc7f594 100644 --- a/mammoth/modules/layer_stack_decoder.py +++ b/onmt/decoders/layer_stack_decoder.py @@ -2,9 +2,9 @@ from torch import nn from typing import Dict, List -from mammoth.modules.decoder import DecoderBase -from mammoth.models.adapters import Adapter, AdaptedTransformerDecoder -from mammoth.distributed import DatasetMetadata +from onmt.decoders.decoder import DecoderBase +from onmt.models.adapters import Adapter, AdaptedTransformerDecoder +from onmt.utils.distributed import DatasetMetadata class LayerStackDecoder(DecoderBase): @@ -17,11 +17,11 @@ def __init__(self, embeddings, decoders): self._active: List[str] = [] @classmethod - def from_opts(cls, opts, embeddings, task_queue_manager, is_on_top=False): + def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False): """Alternate constructor for use during training.""" decoders = nn.ModuleList() - for layer_stack_index, n_layers in enumerate(opts.dec_layers): - is_on_top = layer_stack_index == len(opts.dec_layers) - 1 + for layer_stack_index, n_layers in enumerate(opt.dec_layers): + is_on_top = layer_stack_index == len(opt.dec_layers) - 1 stacks = nn.ModuleDict() for module_id in task_queue_manager.get_decoders(layer_stack_index): if module_id in stacks: @@ -29,26 +29,26 @@ def from_opts(cls, opts, embeddings, task_queue_manager, is_on_top=False): continue stacks[module_id] = AdaptedTransformerDecoder( n_layers, - opts.model_dim, - opts.heads, - opts.transformer_ff, - opts.copy_attn, - opts.self_attn_type, - opts.dropout[0] if type(opts.dropout) is list else opts.dropout, + opt.dec_rnn_size, + opt.heads, + opt.transformer_ff, + opt.copy_attn, + opt.self_attn_type, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, ( - opts.attention_dropout[0] - if type(opts.attention_dropout) is list - else opts.attention_dropout + opt.attention_dropout[0] + if type(opt.attention_dropout) is list + else opt.attention_dropout ), None, # embeddings, - opts.max_relative_positions, - opts.aan_useffn, - opts.full_context_alignment, - opts.alignment_layer, - alignment_heads=opts.alignment_heads, - pos_ffn_activation_fn=opts.pos_ffn_activation_fn, + opt.max_relative_positions, + opt.aan_useffn, + opt.full_context_alignment, + opt.alignment_layer, + alignment_heads=opt.alignment_heads, + pos_ffn_activation_fn=opt.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top + nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top else nn.Identity() ), ) @@ -56,36 +56,36 @@ def from_opts(cls, opts, embeddings, task_queue_manager, is_on_top=False): return cls(embeddings, decoders) @classmethod - def from_trans_opt(cls, model_opts, embeddings, opt_stack): + def from_trans_opt(cls, model_opt, embeddings, opt_stack): """Alternate constructor for use during translation.""" decoders = nn.ModuleList() - for layer_stack_index, n_layers in enumerate(model_opts.dec_layers): + for layer_stack_index, n_layers in enumerate(model_opt.dec_layers): stacks = nn.ModuleDict() - is_on_top = layer_stack_index == len(model_opts.dec_layers) - 1 + is_on_top = layer_stack_index == len(model_opt.dec_layers) - 1 module_opts = opt_stack['decoder'][layer_stack_index] module_id = module_opts['id'] stacks[module_id] = AdaptedTransformerDecoder( n_layers, - model_opts.model_dim, - model_opts.heads, - model_opts.transformer_ff, - model_opts.copy_attn, - model_opts.self_attn_type, - model_opts.dropout[0] if type(model_opts.dropout) is list else model_opts.dropout, + model_opt.dec_rnn_size, + model_opt.heads, + model_opt.transformer_ff, + model_opt.copy_attn, + model_opt.self_attn_type, + model_opt.dropout[0] if type(model_opt.dropout) is list else model_opt.dropout, ( - model_opts.attention_dropout[0] - if type(model_opts.attention_dropout) is list - else model_opts.attention_dropout + model_opt.attention_dropout[0] + if type(model_opt.attention_dropout) is list + else model_opt.attention_dropout ), None, # embeddings, - model_opts.max_relative_positions, - model_opts.aan_useffn, - model_opts.full_context_alignment, - model_opts.alignment_layer, - alignment_heads=model_opts.alignment_heads, - pos_ffn_activation_fn=model_opts.pos_ffn_activation_fn, + model_opt.max_relative_positions, + model_opt.aan_useffn, + model_opt.full_context_alignment, + model_opt.alignment_layer, + alignment_heads=model_opt.alignment_heads, + pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(model_opts.model_dim, eps=1e-6) if is_on_top + nn.LayerNorm(model_opt.dec_rnn_size, eps=1e-6) if is_on_top else nn.Identity() ), ) diff --git a/mammoth/modules/transformer_decoder.py b/onmt/decoders/transformer.py similarity index 94% rename from mammoth/modules/transformer_decoder.py rename to onmt/decoders/transformer.py index a52d7a0a..637db455 100644 --- a/mammoth/modules/transformer_decoder.py +++ b/onmt/decoders/transformer.py @@ -6,11 +6,11 @@ import torch import torch.nn as nn -from mammoth.modules.decoder import DecoderBase -from mammoth.modules import MultiHeadedAttention, AverageAttention -from mammoth.modules.position_ffn import PositionwiseFeedForward -from mammoth.modules.position_ffn import ActivationFunction -from mammoth.utils.misc import sequence_mask +from onmt.decoders.decoder import DecoderBase +from onmt.modules import MultiHeadedAttention, AverageAttention +from onmt.modules.position_ffn import PositionwiseFeedForward +from onmt.modules.position_ffn import ActivationFunction +from onmt.utils.misc import sequence_mask class TransformerDecoderLayerBase(nn.Module): @@ -283,26 +283,26 @@ def __init__(self, d_model, copy_attn, embeddings, alignment_layer, layer_norm_m self.alignment_layer = alignment_layer @classmethod - def from_opts(cls, opts, embeddings, is_on_top=False): + def from_opt(cls, opt, embeddings, is_on_top=False): """Alternate constructor.""" return cls( - opts.dec_layers, - opts.model_dim, - opts.heads, - opts.transformer_ff, - opts.copy_attn, - opts.self_attn_type, - opts.dropout[0] if type(opts.dropout) is list else opts.dropout, - opts.attention_dropout[0] if type(opts.attention_dropout) is list else opts.attention_dropout, + opt.dec_layers, + opt.dec_rnn_size, + opt.heads, + opt.transformer_ff, + opt.copy_attn, + opt.self_attn_type, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opt.attention_dropout[0] if type(opt.attention_dropout) is list else opt.attention_dropout, embeddings, - opts.max_relative_positions, - opts.aan_useffn, - opts.full_context_alignment, - opts.alignment_layer, - alignment_heads=opts.alignment_heads, - pos_ffn_activation_fn=opts.pos_ffn_activation_fn, + opt.max_relative_positions, + opt.aan_useffn, + opt.full_context_alignment, + opt.alignment_layer, + alignment_heads=opt.alignment_heads, + pos_ffn_activation_fn=opt.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top + nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top else nn.Identity() ), ) @@ -366,7 +366,7 @@ class TransformerDecoder(TransformerDecoderBase): self_attn_type (str): type of self-attention scaled-dot, average dropout (float): dropout in residual, self-attn(dot) and feed-forward attention_dropout (float): dropout in context_attn (and self-attn(avg)) - embeddings (mammoth.modules.Embeddings): + embeddings (onmt.modules.Embeddings): embeddings to use, should have positional encodings max_relative_positions (int): Max distance between inputs in relative positions representations diff --git a/onmt/encoders/__init__.py b/onmt/encoders/__init__.py new file mode 100644 index 00000000..7885a187 --- /dev/null +++ b/onmt/encoders/__init__.py @@ -0,0 +1,19 @@ +"""Module defining encoders.""" +from onmt.encoders.encoder import EncoderBase +from onmt.encoders.transformer import TransformerEncoder +from onmt.encoders.ggnn_encoder import GGNNEncoder +from onmt.encoders.rnn_encoder import RNNEncoder +from onmt.encoders.cnn_encoder import CNNEncoder +from onmt.encoders.mean_encoder import MeanEncoder + + +str2enc = { + "ggnn": GGNNEncoder, + "rnn": RNNEncoder, + "brnn": RNNEncoder, + "cnn": CNNEncoder, + "transformer": TransformerEncoder, + "mean": MeanEncoder, +} + +__all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", "MeanEncoder", "str2enc"] diff --git a/onmt/encoders/cnn_encoder.py b/onmt/encoders/cnn_encoder.py new file mode 100644 index 00000000..ffa22a4f --- /dev/null +++ b/onmt/encoders/cnn_encoder.py @@ -0,0 +1,53 @@ +""" +Implementation of "Convolutional Sequence to Sequence Learning" +""" +import torch.nn as nn + +from onmt.encoders.encoder import EncoderBase +from onmt.utils.cnn_factory import shape_transform, StackedCNN + +SCALE_WEIGHT = 0.5**0.5 + + +class CNNEncoder(EncoderBase): + """Encoder based on "Convolutional Sequence to Sequence Learning" + :cite:`DBLP:journals/corr/GehringAGYD17`. + """ + + def __init__(self, num_layers, hidden_size, cnn_kernel_width, dropout, embeddings): + super(CNNEncoder, self).__init__() + + self.embeddings = embeddings + input_size = embeddings.embedding_size + self.linear = nn.Linear(input_size, hidden_size) + self.cnn = StackedCNN(num_layers, hidden_size, cnn_kernel_width, dropout) + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + opt.enc_layers, + opt.enc_rnn_size, + opt.cnn_kernel_width, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + embeddings, + ) + + def forward(self, input, lengths=None, hidden=None): + """See :class:`onmt.modules.EncoderBase.forward()`""" + self._check_args(input, lengths, hidden) + + emb = self.embeddings(input) + # s_len, batch, emb_dim = emb.size() + + emb = emb.transpose(0, 1).contiguous() + emb_reshape = emb.view(emb.size(0) * emb.size(1), -1) + emb_remap = self.linear(emb_reshape) + emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1) + emb_remap = shape_transform(emb_remap) + out = self.cnn(emb_remap) + + return emb_remap.squeeze(3).transpose(0, 1).contiguous(), out.squeeze(3).transpose(0, 1).contiguous(), lengths + + def update_dropout(self, dropout): + self.cnn.dropout.p = dropout diff --git a/mammoth/modules/encoder.py b/onmt/encoders/encoder.py similarity index 90% rename from mammoth/modules/encoder.py rename to onmt/encoders/encoder.py index 9e55792b..71db9cad 100644 --- a/mammoth/modules/encoder.py +++ b/onmt/encoders/encoder.py @@ -2,13 +2,13 @@ import torch.nn as nn -from mammoth.utils.misc import aeq +from onmt.utils.misc import aeq class EncoderBase(nn.Module): """ Base encoder class. Specifies the interface used by different encoder types - and required by :class:`mammoth.Models.NMTModel`. + and required by :class:`onmt.Models.NMTModel`. .. mermaid:: @@ -31,7 +31,7 @@ class EncoderBase(nn.Module): """ @classmethod - def from_opts(cls, opts, embeddings=None): + def from_opt(cls, opt, embeddings=None): raise NotImplementedError def _check_args(self, src, lengths=None, hidden=None): diff --git a/onmt/encoders/ggnn_encoder.py b/onmt/encoders/ggnn_encoder.py new file mode 100644 index 00000000..209b00ab --- /dev/null +++ b/onmt/encoders/ggnn_encoder.py @@ -0,0 +1,311 @@ +"""Define GGNN-based encoders.""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from onmt.encoders.encoder import EncoderBase + + +class GGNNAttrProxy(object): + """ + Translates index lookups into attribute lookups. + To implement some trick which able to use list of nn.Module in a nn.Module + see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2 + """ + + def __init__(self, module, prefix): + self.module = module + self.prefix = prefix + + def __getitem__(self, i): + return getattr(self.module, self.prefix + str(i)) + + +class GGNNPropogator(nn.Module): + """ + Gated Propogator for GGNN + Using LSTM gating mechanism + """ + + def __init__(self, state_dim, n_node, n_edge_types): + super(GGNNPropogator, self).__init__() + + self.n_node = n_node + self.n_edge_types = n_edge_types + + self.reset_gate = nn.Sequential(nn.Linear(state_dim * 3, state_dim), nn.Sigmoid()) + self.update_gate = nn.Sequential(nn.Linear(state_dim * 3, state_dim), nn.Sigmoid()) + self.tansform = nn.Sequential(nn.Linear(state_dim * 3, state_dim), nn.LeakyReLU()) + + def forward(self, state_in, state_out, state_cur, edges, nodes): + edges_in = edges[:, :, : nodes * self.n_edge_types] + edges_out = edges[:, :, nodes * self.n_edge_types:] + + a_in = torch.bmm(edges_in, state_in) + a_out = torch.bmm(edges_out, state_out) + a = torch.cat((a_in, a_out, state_cur), 2) + + r = self.reset_gate(a) + z = self.update_gate(a) + joined_input = torch.cat((a_in, a_out, r * state_cur), 2) + h_hat = self.tansform(joined_input) + + output = (1 - z) * state_cur + z * h_hat + + return output + + +class GGNNEncoder(EncoderBase): + """A gated graph neural network configured as an encoder. + Based on github.com/JamesChuanggg/ggnn.pytorch.git, + which is based on the paper "Gated Graph Sequence Neural Networks" + by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel. + + Args: + rnn_type (str): + style of recurrent unit to use, one of [LSTM] + src_ggnn_size (int) : Size of token-to-node embedding input + src_word_vec_size (int) : Size of token-to-node embedding output + state_dim (int) : Number of state dimensions in nodes + n_edge_types (int) : Number of edge types + bidir_edges (bool): True if reverse edges should be autocreated + n_node (int) : Max nodes in graph + bridge_extra_node (bool): True indicates only 1st extra node + (after token listing) should be used for decoder init. + n_steps (int): Steps to advance graph encoder for stabilization + src_vocab (int): Path to source vocabulary.(The ggnn uses src_vocab + during training because the graph is built using edge information + which requires parsing the input sequence.) + """ + + def __init__( + self, + rnn_type, + src_word_vec_size, + src_ggnn_size, + state_dim, + bidir_edges, + n_edge_types, + n_node, + bridge_extra_node, + n_steps, + src_vocab, + ): + super(GGNNEncoder, self).__init__() + + self.src_word_vec_size = src_word_vec_size + self.src_ggnn_size = src_ggnn_size + self.state_dim = state_dim + self.n_edge_types = n_edge_types + self.n_node = n_node + self.n_steps = n_steps + self.bidir_edges = bidir_edges + self.bridge_extra_node = bridge_extra_node + + for i in range(self.n_edge_types): + # incoming and outgoing edge embedding + in_fc = nn.Linear(self.state_dim, self.state_dim) + out_fc = nn.Linear(self.state_dim, self.state_dim) + self.add_module("in_{}".format(i), in_fc) + self.add_module("out_{}".format(i), out_fc) + + self.in_fcs = GGNNAttrProxy(self, "in_") + self.out_fcs = GGNNAttrProxy(self, "out_") + + # Find vocab data for tree builting + f = open(src_vocab, "r") + idx = 0 + self.COMMA = -1 + self.DELIMITER = -1 + self.idx2num = [] + found_n_minus_one = False + for ln in f: + ln = ln.strip('\n') + ln = ln.split('\t')[0] + if idx == 0 and ln != "": + idx += 1 + self.idx2num.append(-1) + if idx == 1 and ln != "": + idx += 1 + self.idx2num.append(-1) + if ln == ",": + self.COMMA = idx + if ln == "": + self.DELIMITER = idx + if ln.isdigit(): + self.idx2num.append(int(ln)) + if int(ln) == n_node - 1: + found_n_minus_one = True + else: + self.idx2num.append(-1) + idx += 1 + + assert self.COMMA >= 0, "GGNN src_vocab must include ',' character" + assert self.DELIMITER >= 0, "GGNN src_vocab must include token" + assert found_n_minus_one, "GGNN src_vocab must include node numbers for edge connections" + + # Propogation Model + self.propogator = GGNNPropogator(self.state_dim, self.n_node, self.n_edge_types) + + self._initialization() + + # Initialize the bridge layer + self._initialize_bridge(rnn_type, self.state_dim, 1) + + # Token embedding + if src_ggnn_size > 0: + self.embed = nn.Sequential(nn.Linear(src_ggnn_size, src_word_vec_size), nn.LeakyReLU()) + assert self.src_ggnn_size >= self.DELIMITER, "Embedding input must be larger than vocabulary" + assert self.src_word_vec_size < self.state_dim, "Embedding size must be smaller than state_dim" + else: + assert self.DELIMITER < self.state_dim, "Vocabulary too large, consider -src_ggnn_size" + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + opt.rnn_type, + opt.src_word_vec_size, + opt.src_ggnn_size, + opt.state_dim, + opt.bidir_edges, + opt.n_edge_types, + opt.n_node, + opt.bridge_extra_node, + opt.n_steps, + opt.src_vocab, + ) + + def _initialization(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + m.weight.data.normal_(0.0, 0.02) + m.bias.data.fill_(0) + + def forward(self, src, lengths=None): + """See :func:`EncoderBase.forward()`""" + self._check_args(src, lengths) + nodes = self.n_node + batch_size = src.size()[1] + first_extra = np.zeros(batch_size, dtype=np.int32) + token_onehot = np.zeros( + (batch_size, nodes, self.src_ggnn_size if self.src_ggnn_size > 0 else self.state_dim), dtype=np.int32 + ) + edges = np.zeros((batch_size, nodes, nodes * self.n_edge_types * 2), dtype=np.int32) + npsrc = src[:, :, 0].cpu().data.numpy().astype(np.int32) + + # Initialize graph using formatted input sequence + for i in range(batch_size): + tokens_done = False + # Number of flagged nodes defines node count for this sample + # (Nodes can have no flags on them, but must be in 'flags' list). + flag_node = 0 + flags_done = False + edge = 0 + source_node = -1 + for j in range(len(npsrc)): + token = npsrc[j][i] + if not tokens_done: + if token == self.DELIMITER: + tokens_done = True + first_extra[i] = j + else: + token_onehot[i][j][token] = 1 + elif token == self.DELIMITER: + flag_node += 1 + flags_done = True + assert flag_node <= nodes, "Too many nodes with flags" + elif not flags_done: + # The total number of integers in the vocab should allow + # for all features and edges to be defined. + if token == self.COMMA: + flag_node = 0 + else: + num = self.idx2num[token] + if num >= 0: + token_onehot[i][flag_node][num + self.DELIMITER] = 1 + flag_node += 1 + elif token == self.COMMA: + edge += 1 + assert source_node == -1, f'Error in graph edge input: {source_node} unpaired' + assert edge < self.n_edge_types, "Too many edge types in input" + else: + num = self.idx2num[token] + if source_node < 0: + source_node = num + else: + edges[i][source_node][num + nodes * edge] = 1 + if self.bidir_edges: + edges[i][num][nodes * (edge + self.n_edge_types) + source_node] = 1 + source_node = -1 + + token_onehot = torch.from_numpy(token_onehot).float().to(src.device) + if self.src_ggnn_size > 0: + token_embed = self.embed(token_onehot) + prop_state = torch.cat( + ( + token_embed, + torch.zeros((batch_size, nodes, self.state_dim - self.src_word_vec_size)).float().to(src.device), + ), + 2, + ) + else: + prop_state = token_onehot + edges = torch.from_numpy(edges).float().to(src.device) + + for i_step in range(self.n_steps): + in_states = [] + out_states = [] + for i in range(self.n_edge_types): + in_states.append(self.in_fcs[i](prop_state)) + out_states.append(self.out_fcs[i](prop_state)) + in_states = torch.stack(in_states).transpose(0, 1).contiguous() + in_states = in_states.view(-1, nodes * self.n_edge_types, self.state_dim) + out_states = torch.stack(out_states).transpose(0, 1).contiguous() + out_states = out_states.view(-1, nodes * self.n_edge_types, self.state_dim) + + prop_state = self.propogator(in_states, out_states, prop_state, edges, nodes) + + prop_state = prop_state.transpose(0, 1) + if self.bridge_extra_node: + # Use first extra node as only source for decoder init + join_state = prop_state[first_extra, torch.arange(batch_size)] + else: + # Average all nodes to get bridge input + join_state = prop_state.mean(0) + join_state = torch.stack((join_state, join_state, join_state, join_state)) + join_state = (join_state, join_state) + + encoder_final = self._bridge(join_state) + + return encoder_final, prop_state, lengths + + def _initialize_bridge(self, rnn_type, hidden_size, num_layers): + + # LSTM has hidden and cell state, other only one + number_of_states = 2 if rnn_type == "LSTM" else 1 + # Total number of states + self.total_hidden_dim = hidden_size * num_layers + + # Build a linear layer for each + self.bridge = nn.ModuleList( + [nn.Linear(self.total_hidden_dim, self.total_hidden_dim, bias=True) for _ in range(number_of_states)] + ) + + def _bridge(self, hidden): + """Forward hidden state through bridge.""" + + def bottle_hidden(linear, states): + """ + Transform from 3D to 2D, apply linear and return initial size + """ + size = states.size() + result = linear(states.view(-1, self.total_hidden_dim)) + return F.leaky_relu(result).view(size) + + if isinstance(hidden, tuple): # LSTM + outs = tuple([bottle_hidden(layer, hidden[ix]) for ix, layer in enumerate(self.bridge)]) + else: + outs = bottle_hidden(self.bridge[0], hidden) + return outs diff --git a/mammoth/modules/layer_stack_encoder.py b/onmt/encoders/layer_stack_encoder.py similarity index 76% rename from mammoth/modules/layer_stack_encoder.py rename to onmt/encoders/layer_stack_encoder.py index a8de6dd4..77073fd9 100644 --- a/mammoth/modules/layer_stack_encoder.py +++ b/onmt/encoders/layer_stack_encoder.py @@ -1,10 +1,10 @@ from torch import nn from typing import Dict, List -from mammoth.modules.encoder import EncoderBase -from mammoth.models.adapters import Adapter, AdaptedTransformerEncoder -from mammoth.utils.misc import sequence_mask -from mammoth.distributed import DatasetMetadata +from onmt.encoders.encoder import EncoderBase +from onmt.models.adapters import Adapter, AdaptedTransformerEncoder +from onmt.utils.misc import sequence_mask +from onmt.utils.distributed import DatasetMetadata class LayerStackEncoder(EncoderBase): @@ -17,32 +17,32 @@ def __init__(self, embeddings, encoders): self._active: List[str] = [] @classmethod - def from_opts(cls, opts, embeddings, task_queue_manager): + def from_opt(cls, opt, embeddings, task_queue_manager): """Alternate constructor for use during training.""" encoders = nn.ModuleList() - for layer_stack_index, n_layers in enumerate(opts.enc_layers): + for layer_stack_index, n_layers in enumerate(opt.enc_layers): stacks = nn.ModuleDict() - is_on_top = layer_stack_index == len(opts.enc_layers) - 1 + is_on_top = layer_stack_index == len(opt.enc_layers) - 1 for module_id in task_queue_manager.get_encoders(layer_stack_index): if module_id in stacks: # several tasks using the same layer stack continue stacks[module_id] = AdaptedTransformerEncoder( n_layers, - opts.model_dim, - opts.heads, - opts.transformer_ff, - opts.dropout[0] if type(opts.dropout) is list else opts.dropout, + opt.enc_rnn_size, + opt.heads, + opt.transformer_ff, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, ( - opts.attention_dropout[0] - if type(opts.attention_dropout) is list - else opts.attention_dropout + opt.attention_dropout[0] + if type(opt.attention_dropout) is list + else opt.attention_dropout ), None, # embeddings, - opts.max_relative_positions, - pos_ffn_activation_fn=opts.pos_ffn_activation_fn, + opt.max_relative_positions, + pos_ffn_activation_fn=opt.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top + nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top else nn.Identity() ) ) @@ -50,30 +50,30 @@ def from_opts(cls, opts, embeddings, task_queue_manager): return cls(embeddings, encoders) @classmethod - def from_trans_opt(cls, model_opts, embeddings, opt_stack): + def from_trans_opt(cls, model_opt, embeddings, opt_stack): """Alternate constructor for use during translation.""" encoders = nn.ModuleList() - for layer_stack_index, n_layers in enumerate(model_opts.enc_layers): + for layer_stack_index, n_layers in enumerate(model_opt.enc_layers): stacks = nn.ModuleDict() module_opts = opt_stack['encoder'][layer_stack_index] module_id = module_opts['id'] - is_on_top = layer_stack_index == len(model_opts.enc_layers) - 1 + is_on_top = layer_stack_index == len(model_opt.enc_layers) - 1 stacks[module_id] = AdaptedTransformerEncoder( n_layers, - model_opts.model_dim, - model_opts.heads, - model_opts.transformer_ff, - model_opts.dropout[0] if type(model_opts.dropout) is list else model_opts.dropout, + model_opt.enc_rnn_size, + model_opt.heads, + model_opt.transformer_ff, + model_opt.dropout[0] if type(model_opt.dropout) is list else model_opt.dropout, ( - model_opts.attention_dropout[0] - if type(model_opts.attention_dropout) is list - else model_opts.attention_dropout + model_opt.attention_dropout[0] + if type(model_opt.attention_dropout) is list + else model_opt.attention_dropout ), None, # embeddings, - model_opts.max_relative_positions, - pos_ffn_activation_fn=model_opts.pos_ffn_activation_fn, + model_opt.max_relative_positions, + pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(model_opts.model_dim, eps=1e-6) if is_on_top + nn.LayerNorm(model_opt.enc_rnn_size, eps=1e-6) if is_on_top else nn.Identity() ) ) diff --git a/mammoth/modules/mean_encoder.py b/onmt/encoders/mean_encoder.py similarity index 81% rename from mammoth/modules/mean_encoder.py rename to onmt/encoders/mean_encoder.py index 943a099e..ca903c99 100644 --- a/mammoth/modules/mean_encoder.py +++ b/onmt/encoders/mean_encoder.py @@ -1,6 +1,6 @@ """Define a minimal encoder.""" -from mammoth.modules.encoder import EncoderBase -from mammoth.utils.misc import sequence_mask +from onmt.encoders.encoder import EncoderBase +from onmt.utils.misc import sequence_mask import torch @@ -9,7 +9,7 @@ class MeanEncoder(EncoderBase): Args: num_layers (int): number of replicated layers - embeddings (mammoth.modules.Embeddings): embedding module to use + embeddings (onmt.modules.Embeddings): embedding module to use """ def __init__(self, num_layers, embeddings): @@ -18,9 +18,9 @@ def __init__(self, num_layers, embeddings): self.embeddings = embeddings @classmethod - def from_opts(cls, opts, embeddings): + def from_opt(cls, opt, embeddings): """Alternate constructor.""" - return cls(opts.enc_layers, embeddings) + return cls(opt.enc_layers, embeddings) def forward(self, src, lengths=None): """See :func:`EncoderBase.forward()`""" diff --git a/onmt/encoders/rnn_encoder.py b/onmt/encoders/rnn_encoder.py new file mode 100644 index 00000000..78271050 --- /dev/null +++ b/onmt/encoders/rnn_encoder.py @@ -0,0 +1,115 @@ +"""Define RNN-based encoders.""" +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn.utils.rnn import pack_padded_sequence as pack +from torch.nn.utils.rnn import pad_packed_sequence as unpack + +from onmt.encoders.encoder import EncoderBase +from onmt.utils.rnn_factory import rnn_factory + + +class RNNEncoder(EncoderBase): + """A generic recurrent neural network encoder. + + Args: + rnn_type (str): + style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU] + bidirectional (bool) : use a bidirectional RNN + num_layers (int) : number of stacked layers + hidden_size (int) : hidden size of each layer + dropout (float) : dropout value for :class:`torch.nn.Dropout` + embeddings (onmt.modules.Embeddings): embedding module to use + """ + + def __init__( + self, rnn_type, bidirectional, num_layers, hidden_size, dropout=0.0, embeddings=None, use_bridge=False + ): + super(RNNEncoder, self).__init__() + assert embeddings is not None + + num_directions = 2 if bidirectional else 1 + assert hidden_size % num_directions == 0 + hidden_size = hidden_size // num_directions + self.embeddings = embeddings + + self.rnn, self.no_pack_padded_seq = rnn_factory( + rnn_type, + input_size=embeddings.embedding_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + bidirectional=bidirectional, + ) + + # Initialize the bridge layer + self.use_bridge = use_bridge + if self.use_bridge: + self._initialize_bridge(rnn_type, hidden_size, num_layers) + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + opt.rnn_type, + opt.brnn, + opt.enc_layers, + opt.enc_rnn_size, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + embeddings, + opt.bridge, + ) + + def forward(self, src, lengths=None): + """See :func:`EncoderBase.forward()`""" + self._check_args(src, lengths) + + emb = self.embeddings(src) + # s_len, batch, emb_dim = emb.size() + + packed_emb = emb + if lengths is not None and not self.no_pack_padded_seq: + # Lengths data is wrapped inside a Tensor. + lengths_list = lengths.view(-1).tolist() + packed_emb = pack(emb, lengths_list) + + memory_bank, encoder_final = self.rnn(packed_emb) + + if lengths is not None and not self.no_pack_padded_seq: + memory_bank = unpack(memory_bank)[0] + + if self.use_bridge: + encoder_final = self._bridge(encoder_final) + return encoder_final, memory_bank, lengths + + def _initialize_bridge(self, rnn_type, hidden_size, num_layers): + + # LSTM has hidden and cell state, other only one + number_of_states = 2 if rnn_type == "LSTM" else 1 + # Total number of states + self.total_hidden_dim = hidden_size * num_layers + + # Build a linear layer for each + self.bridge = nn.ModuleList( + [nn.Linear(self.total_hidden_dim, self.total_hidden_dim, bias=True) for _ in range(number_of_states)] + ) + + def _bridge(self, hidden): + """Forward hidden state through bridge.""" + + def bottle_hidden(linear, states): + """ + Transform from 3D to 2D, apply linear and return initial size + """ + size = states.size() + result = linear(states.view(-1, self.total_hidden_dim)) + return F.relu(result).view(size) + + if isinstance(hidden, tuple): # LSTM + outs = tuple([bottle_hidden(layer, hidden[ix]) for ix, layer in enumerate(self.bridge)]) + else: + outs = bottle_hidden(self.bridge[0], hidden) + return outs + + def update_dropout(self, dropout): + self.rnn.dropout = dropout diff --git a/mammoth/modules/transformer_encoder.py b/onmt/encoders/transformer.py similarity index 85% rename from mammoth/modules/transformer_encoder.py rename to onmt/encoders/transformer.py index 43fd92ca..c020aa7d 100644 --- a/mammoth/modules/transformer_encoder.py +++ b/onmt/encoders/transformer.py @@ -4,11 +4,11 @@ import torch.nn as nn -from mammoth.modules.encoder import EncoderBase -from mammoth.modules import MultiHeadedAttention -from mammoth.modules.position_ffn import PositionwiseFeedForward -from mammoth.modules.position_ffn import ActivationFunction -from mammoth.utils.misc import sequence_mask +from onmt.encoders.encoder import EncoderBase +from onmt.modules import MultiHeadedAttention +from onmt.modules.position_ffn import PositionwiseFeedForward +from onmt.modules.position_ffn import ActivationFunction +from onmt.utils.misc import sequence_mask class TransformerEncoderLayer(nn.Module): @@ -88,7 +88,7 @@ class TransformerEncoder(EncoderBase): heads (int): number of heads d_ff (int): size of the inner FF layer dropout (float): dropout parameters - embeddings (mammoth.modules.Embeddings): + embeddings (onmt.modules.Embeddings): embeddings to use, should have positional encodings pos_ffn_activation_fn (ActivationFunction): activation function choice for PositionwiseFeedForward layer @@ -133,20 +133,20 @@ def __init__( self.layer_norm = layer_norm_module @classmethod - def from_opts(cls, opts, embeddings, is_on_top=False): + def from_opt(cls, opt, embeddings, is_on_top=False): """Alternate constructor.""" return cls( - opts.enc_layers, - opts.model_dim, - opts.heads, - opts.transformer_ff, - opts.dropout[0] if type(opts.dropout) is list else opts.dropout, - opts.attention_dropout[0] if type(opts.attention_dropout) is list else opts.attention_dropout, + opt.enc_layers, + opt.enc_rnn_size, + opt.heads, + opt.transformer_ff, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opt.attention_dropout[0] if type(opt.attention_dropout) is list else opt.attention_dropout, embeddings, - opts.max_relative_positions, - pos_ffn_activation_fn=opts.pos_ffn_activation_fn, + opt.max_relative_positions, + pos_ffn_activation_fn=opt.pos_ffn_activation_fn, layer_norm_module=( - nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top + nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top else nn.Identity() ) ) diff --git a/mammoth/inputters/__init__.py b/onmt/inputters/__init__.py similarity index 61% rename from mammoth/inputters/__init__.py rename to onmt/inputters/__init__.py index 4d3d1dbe..8c32e4d3 100644 --- a/mammoth/inputters/__init__.py +++ b/onmt/inputters/__init__.py @@ -1,4 +1,4 @@ -"""The point of this package is to provide: +"""The point of this package is to provide a minimal viable product with: - vocab loading (cf. vocab.py) - token-counts based batch sampler (cf. dataloader.py) - on the fly pad, bos, eos, unk handling (cf. dataset.py) @@ -6,9 +6,9 @@ - multiple parallel corpora, in accordance with TaskDistributor (cf. distributed.py) """ -from mammoth.inputters.dataloader import build_dataloader, DynamicDatasetIter -from mammoth.inputters.dataset import get_corpus, build_vocab_counts, ParallelCorpus -from mammoth.inputters.vocab import get_vocab, DEFAULT_SPECIALS +from onmt.inputters.dataloader import build_dataloader, DynamicDatasetIter +from onmt.inputters.dataset import get_corpus, build_vocab_counts, ParallelCorpus +from onmt.inputters.vocab import get_vocab, DEFAULT_SPECIALS __all__ = [ diff --git a/mammoth/inputters/dataloader.py b/onmt/inputters/dataloader.py similarity index 96% rename from mammoth/inputters/dataloader.py rename to onmt/inputters/dataloader.py index 7e26987b..ed707c94 100644 --- a/mammoth/inputters/dataloader.py +++ b/onmt/inputters/dataloader.py @@ -4,8 +4,8 @@ import torch -from mammoth.inputters.dataset import get_corpus -from mammoth.utils.logging import logger +from onmt.inputters.dataset import get_corpus +from onmt.utils.logging import logger def infinite_iterator(iterable): @@ -13,7 +13,7 @@ def infinite_iterator(iterable): def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True, as_iter=True): - """Convert an mammoth.inputters.ParallelCorpus into an infinite iterator of batches""" + """Convert an onmt.inputters.ParallelCorpus into an infinite iterator of batches""" if not cycle: loader = InferenceBatcher(dataset, batch_size) else: @@ -187,7 +187,7 @@ class DynamicDatasetIter(object): batch_size (int): numbers of examples in a batch; batch_size_multiple (int): make batch size multiply of this; data_type (str): input data type, currently only text; - pool_size (int): accum this number of examples in a dynamic dataset; + bucket_size (int): accum this number of examples in a dynamic dataset; skip_empty_level (str): security level when encouter empty line; stride (int): iterate data files with this stride; offset (int): iterate data files with this offset. @@ -209,7 +209,7 @@ def __init__( batch_size, batch_size_multiple, data_type="text", - pool_size=2048, + bucket_size=2048, n_buckets=1024, skip_empty_level='warning', ): @@ -225,7 +225,7 @@ def __init__( self.batch_size = batch_size self.batch_size_multiple = batch_size_multiple self.device = 'cpu' - self.pool_size = pool_size + self.bucket_size = bucket_size self.n_buckets = n_buckets if skip_empty_level not in ['silent', 'warning', 'error']: raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}") @@ -242,7 +242,7 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra return cls( task_queue_manager, opts, - opts.tasks, + opts.data, transforms_cls, vocabs_dict, is_train, @@ -250,7 +250,7 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra batch_size, batch_size_multiple, data_type=opts.data_type, - pool_size=opts.pool_size, + bucket_size=opts.bucket_size, n_buckets=opts.n_buckets, skip_empty_level=opts.skip_empty_level, ) @@ -275,7 +275,7 @@ def _init_datasets(self): # Case 2: we are validation (hence self.is_train := False), we need an iterator # if and only the task defines validation data, i.e. if the key `path_valid_src` # is defined - if self.is_train or self.opts.tasks[task.corpus_id].get('path_valid_src', None) is not None: + if self.is_train or self.opts.data[task.corpus_id].get('path_valid_src', None) is not None: corpus = get_corpus( self.opts, task, src_vocab, tgt_vocab, is_train=self.is_train ).to(device) @@ -285,7 +285,7 @@ def _init_datasets(self): corpus, self.batch_size, self.batch_type, - self.pool_size, + self.bucket_size, n_buckets=self.n_buckets, cycle=self.is_train, as_iter=self.is_train, diff --git a/mammoth/inputters/dataset.py b/onmt/inputters/dataset.py similarity index 95% rename from mammoth/inputters/dataset.py rename to onmt/inputters/dataset.py index 6bb2d1a9..a8044dec 100644 --- a/mammoth/inputters/dataset.py +++ b/onmt/inputters/dataset.py @@ -8,10 +8,10 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils.data import IterableDataset -from mammoth.constants import DefaultTokens -from mammoth.transforms import TransformPipe, get_transforms_cls, make_transforms -from mammoth.utils.logging import logger -from mammoth.inputters.vocab import Vocab +from onmt.constants import DefaultTokens +from onmt.transforms import TransformPipe, get_transforms_cls, make_transforms +from onmt.utils.logging import logger +from onmt.inputters.vocab import Vocab @dataclass @@ -102,7 +102,7 @@ def __init__( self.is_train = is_train self.corpus_id = task.corpus_id - # FIXME: most likely redundant with mammoth.transforms.tokenize + # FIXME: most likely redundant with onmt.transforms.tokenize def _tokenize(self, string, side='src'): """Split string, accompanied by a drumroll""" return string.split() @@ -177,7 +177,7 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool = # get transform classes to infer special tokens # FIXME ensure TQM properly initializes transform with global if necessary vocabs = {'src': src_vocab, 'tgt': tgt_vocab} - corpus_opts = opts.tasks[task.corpus_id] + corpus_opts = opts.data[task.corpus_id] transforms_to_apply = corpus_opts.get('transforms', None) transforms_to_apply = transforms_to_apply or opts.get('transforms', None) transforms_to_apply = transforms_to_apply or [] @@ -245,8 +245,8 @@ def build_vocab_counts(opts, corpus_id, transforms, n_sample=3): corpora = { corpus_id: read_examples_from_files( - opts.tasks[corpus_id]["path_src"], - opts.tasks[corpus_id]["path_tgt"], + opts.data[corpus_id]["path_src"], + opts.data[corpus_id]["path_tgt"], # FIXME this is likely not working transforms_fn=TransformPipe(transforms).apply if transforms else lambda x: x, ) diff --git a/mammoth/inputters/vocab.py b/onmt/inputters/vocab.py similarity index 97% rename from mammoth/inputters/vocab.py rename to onmt/inputters/vocab.py index cc3cbd31..b2ed194b 100644 --- a/mammoth/inputters/vocab.py +++ b/onmt/inputters/vocab.py @@ -3,8 +3,8 @@ import itertools import os -from mammoth.utils.logging import logger -from mammoth.constants import DefaultTokens +from onmt.utils.logging import logger +from onmt.constants import DefaultTokens DEFAULT_SPECIALS = (DefaultTokens.BOS, DefaultTokens.EOS, DefaultTokens.UNK, DefaultTokens.PAD) diff --git a/mammoth/model_builder.py b/onmt/model_builder.py similarity index 69% rename from mammoth/model_builder.py rename to onmt/model_builder.py index 35dfec85..69abb0a7 100644 --- a/mammoth/model_builder.py +++ b/onmt/model_builder.py @@ -7,92 +7,95 @@ from torch.nn.init import xavier_uniform_ from collections import defaultdict +# from torchtext.legacy.data import Field -import mammoth.modules +import onmt.modules -from mammoth.models.adapters import ( +from onmt.models.adapters import ( Adapter, EncoderAdapterLayer, DecoderAdapterLayer, ) -from mammoth.constants import ModelTask, DefaultTokens -from mammoth.modules.layer_stack_decoder import LayerStackDecoder -from mammoth.modules.layer_stack_encoder import LayerStackEncoder -from mammoth.modules import Embeddings -from mammoth.modules.embeddings import PluggableEmbeddings -from mammoth.modules.util_class import Cast -from mammoth.utils.logging import logger -from mammoth.utils.misc import use_gpu -from mammoth.utils.module_splitter import _combine_ordered_dicts -from mammoth.utils.parse import ArgumentParser +from onmt.constants import ModelTask, DefaultTokens +from onmt.decoders.layer_stack_decoder import LayerStackDecoder +from onmt.encoders.layer_stack_encoder import LayerStackEncoder +from onmt.modules import Embeddings +from onmt.modules.embeddings import PluggableEmbeddings +from onmt.modules.util_class import Cast +from onmt.utils.logging import logger +from onmt.utils.misc import use_gpu +from onmt.utils.module_splitter import _combine_ordered_dicts +from onmt.utils.parse import ArgumentParser -from mammoth.modules.attention_bridge import AttentionBridge +from onmt.attention_bridge import AttentionBridge -def build_embeddings(opts, vocab, for_encoder=True): +def build_embeddings(opt, vocab, for_encoder=True): """ Args: - opts: the option in current environment. + opt: the option in current environment. vocab: stoi-ish object. for_encoder(bool): build Embeddings for encoder or decoder? """ + emb_dim = opt.src_word_vec_size if for_encoder else opt.tgt_word_vec_size word_padding_idx = vocab.stoi[DefaultTokens.PAD] - opts.word_padding_idx = word_padding_idx + opt.word_padding_idx = word_padding_idx - freeze_word_vecs = opts.freeze_word_vecs_enc if for_encoder else opts.freeze_word_vecs_dec + freeze_word_vecs = opt.freeze_word_vecs_enc if for_encoder else opt.freeze_word_vecs_dec emb = Embeddings( - word_vec_size=opts.model_dim, - position_encoding=opts.position_encoding, - dropout=opts.dropout[0] if type(opts.dropout) is list else opts.dropout, + word_vec_size=emb_dim, + position_encoding=opt.position_encoding, + dropout=opt.dropout[0] if type(opt.dropout) is list else opt.dropout, word_padding_idx=word_padding_idx, word_vocab_size=len(vocab), + sparse=opt.optim == "sparseadam", freeze_word_vecs=freeze_word_vecs, ) return emb -def build_encoder(opts, embeddings, task_queue_manager): +def build_encoder(opt, embeddings, task_queue_manager): """ Various encoder dispatcher function. Args: - opts: the option in current environment. + opt: the option in current environment. embeddings (Embeddings): vocab embeddings for this encoder. """ - assert opts.encoder_type == 'transformer', 'Only Transformer is supported' - return LayerStackEncoder.from_opts(opts, embeddings, task_queue_manager) + assert opt.encoder_type == 'transformer', 'Only Transformer is supported' + return LayerStackEncoder.from_opt(opt, embeddings, task_queue_manager) -def build_decoder(opts, embeddings, task_queue_manager): +def build_decoder(opt, embeddings, task_queue_manager): """ Various decoder dispatcher function. Args: - opts: the option in current environment. + opt: the option in current environment. embeddings (Embeddings): vocab embeddings for this decoder. """ - assert opts.decoder_type == 'transformer', 'Only Transformer is supported' - return LayerStackDecoder.from_opts(opts, embeddings, task_queue_manager) + assert opt.decoder_type == 'transformer', 'Only Transformer is supported' + return LayerStackDecoder.from_opt(opt, embeddings, task_queue_manager) -def load_test_multitask_model(opts, model_path=None): +def load_test_multitask_model(opt, model_path=None): """If a checkpoint ending with ".pt" returns a full model otherwise it builds a bilingual model""" if model_path is None: - model_path = opts.models[0] + model_path = opt.models[0] - opts.lang_pair = opts.lang_pair if opts.lang_pair else f'{opts.src_lang}-{opts.tgt_lang}' + opt.lang_pair = opt.lang_pair if opt.lang_pair else f'{opt.src_lang}-{opt.tgt_lang}' if model_path.endswith('.pt'): - return load_test_model(opts, model_path) + return load_test_model(opt, model_path) else: checkpoint_modules = [ - (f'encoder.embeddings.embeddings_{opts.src_lang}.', f'src_embeddings_{opts.src_lang}'), - (f'decoder.embeddings.embeddings_{opts.tgt_lang}.', f'tgt_embeddings_{opts.tgt_lang}'), - (f'generator.generator_{opts.tgt_lang}.', f'generator_{opts.tgt_lang}'), + (f'encoder.embeddings.embeddings_{opt.src_lang}.', f'src_embeddings_{opt.src_lang}'), + (f'decoder.embeddings.embeddings_{opt.tgt_lang}.', f'tgt_embeddings_{opt.tgt_lang}'), + (f'generator.generator_{opt.tgt_lang}.', f'generator_{opt.tgt_lang}'), ('attention_bridge.', 'attention_bridge'), ] - for layer_stack_idx, layer_stack_opt in enumerate(opts.stack['encoder']): + for layer_stack_idx, layer_stack_opt in enumerate(opt.stack['encoder']): layer_stack_key = layer_stack_opt['id'] checkpoint_modules.append( ( @@ -107,7 +110,7 @@ def load_test_multitask_model(opts, model_path=None): f'encoder_adapter_{layer_stack_idx}_{layer_stack_key}_{adapter_group}_{sub_id}' ) ) - for layer_stack_idx, layer_stack_opt in enumerate(opts.stack['decoder']): + for layer_stack_idx, layer_stack_opt in enumerate(opt.stack['decoder']): layer_stack_key = layer_stack_opt['id'] checkpoint_modules.append( ( @@ -128,8 +131,8 @@ def load_test_multitask_model(opts, model_path=None): (prefix, f'{model_path}_{key}.pt') for (prefix, key) in checkpoint_modules ] - opts.model_frame = model_path + '_frame.pt' - frame = torch.load(opts.model_frame, map_location=lambda storage, loc: storage) + opt.model_frame = model_path + '_frame.pt' + frame = torch.load(opt.model_frame, map_location=lambda storage, loc: storage) checkpoint_state_dicts = { prefix: torch.load(path, map_location=lambda storage, loc: storage) @@ -139,20 +142,20 @@ def load_test_multitask_model(opts, model_path=None): combined_state_dict = _combine_ordered_dicts(checkpoint_state_dicts) vocabs_dict = { - 'src': frame["vocab"].get(('src', opts.src_lang)), - 'tgt': frame["vocab"].get(('tgt', opts.tgt_lang)), + 'src': frame["vocab"].get(('src', opt.src_lang)), + 'tgt': frame["vocab"].get(('tgt', opt.tgt_lang)), } # FIXME # fields["indices"] = Field(use_vocab=False, dtype=torch.long, sequential=False) - model_opts = ArgumentParser.ckpt_model_opts(frame['opts']) + model_opt = ArgumentParser.ckpt_model_opts(frame['opt']) # Avoid functionality on inference - model_opts.update_vocab = False + model_opt.update_vocab = False model = create_bilingual_model( - src_lang=opts.src_lang, - tgt_lang=opts.tgt_lang, - opt_stack=opts.stack, - model_opts=model_opts, + src_lang=opt.src_lang, + tgt_lang=opt.tgt_lang, + opt_stack=opt.stack, + model_opt=model_opt, vocabs_dict=vocabs_dict ) model_params = {name for name, p in model.named_parameters()} @@ -165,24 +168,24 @@ def load_test_multitask_model(opts, model_path=None): if key not in combined_state_dict: print(f'Key missing {key}') model.load_state_dict(combined_state_dict) - device = torch.device("cuda" if use_gpu(opts) else "cpu") + device = torch.device("cuda" if use_gpu(opt) else "cpu") model.to(device) model.eval() - return vocabs_dict, model, model_opts + return vocabs_dict, model, model_opt -def load_test_model(opts, model_path=None): +def load_test_model(opt, model_path=None): if model_path is None: - model_path = opts.models[0] + model_path = opt.models[0] - if len(opts.models) > 1: - model_path_enc = opts.models[0] + if len(opt.models) > 1: + model_path_enc = opt.models[0] checkpoint = torch.load(model_path_enc, map_location=lambda storage, loc: storage) model = checkpoint['whole_model'] - model_path_dec = opts.models[1] + model_path_dec = opt.models[1] model_dec = torch.load(model_path_dec, map_location=lambda storage, loc: storage)['whole_model'] model.decoder = model_dec.decoder model.generator = model_dec.generator @@ -190,17 +193,17 @@ def load_test_model(opts, model_path=None): checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) model = checkpoint['whole_model'] - model_opts = ArgumentParser.ckpt_model_opts(checkpoint['opts']) - ArgumentParser.update_model_opts(model_opts) - ArgumentParser.validate_model_opts(model_opts) + model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt']) + ArgumentParser.update_model_opts(model_opt) + ArgumentParser.validate_model_opts(model_opt) vocabs = checkpoint['vocab'] print("VOCABS") print(vocabs) - if opts.gpu != -1: + if opt.gpu != -1: device = torch.device("cuda") model.to(device) - lang_pair = opts.lang_pair + lang_pair = opt.lang_pair src_lang, tgt_lang = lang_pair.split("-") # FIXME vocabs_dict = {} @@ -210,48 +213,48 @@ def load_test_model(opts, model_path=None): # fields["indices"] = indices # Avoid functionality on inference - model_opts.update_vocab = False + model_opt.update_vocab = False - if opts.fp32: + if opt.fp32: model.float() - elif opts.int8: - if opts.gpu >= 0: + elif opt.int8: + if opt.gpu >= 0: raise ValueError("Dynamic 8-bit quantization is not supported on GPU") torch.quantization.quantize_dynamic(model, inplace=True) model.eval() model.generator.eval() - return vocabs_dict, model, model_opts + return vocabs_dict, model, model_opt def create_bilingual_model( - src_lang, tgt_lang, opt_stack, model_opts, vocabs_dict + src_lang, tgt_lang, opt_stack, model_opt, vocabs_dict ): """For translation.""" generators_md = nn.ModuleDict() - src_emb = build_src_emb(model_opts, vocabs_dict['src']) - tgt_emb = build_tgt_emb(model_opts, vocabs_dict['tgt']) + src_emb = build_src_emb(model_opt, vocabs_dict['src']) + tgt_emb = build_tgt_emb(model_opt, vocabs_dict['tgt']) pluggable_src_emb = PluggableEmbeddings({src_lang: src_emb}) pluggable_tgt_emb = PluggableEmbeddings({tgt_lang: tgt_emb}) pluggable_src_emb.activate(src_lang) pluggable_tgt_emb.activate(tgt_lang) - encoder = LayerStackEncoder.from_trans_opt(model_opts, pluggable_src_emb, opt_stack) - decoder = LayerStackDecoder.from_trans_opt(model_opts, pluggable_tgt_emb, opt_stack) - generator = build_generator(model_opts, len(vocabs_dict['tgt']), tgt_emb) + encoder = LayerStackEncoder.from_trans_opt(model_opt, pluggable_src_emb, opt_stack) + decoder = LayerStackDecoder.from_trans_opt(model_opt, pluggable_tgt_emb, opt_stack) + generator = build_generator(model_opt, len(vocabs_dict['tgt']), tgt_emb) generators_md.add_module(f'generator_{tgt_lang}', generator) - attention_bridge = AttentionBridge.from_opts(model_opts) + attention_bridge = AttentionBridge.from_opt(model_opt) - nmt_model = mammoth.models.NMTModel( + nmt_model = onmt.models.NMTModel( encoder=encoder, decoder=decoder, attention_bridge=attention_bridge ) - if uses_adapters(model_opts): + if uses_adapters(model_opt): logger.info('Creating adapters...') - create_bilingual_adapters(nmt_model, model_opts, src_lang, tgt_lang, opt_stack) + create_bilingual_adapters(nmt_model, model_opt, src_lang, tgt_lang, opt_stack) else: logger.info('Does not use adapters...') print('built model:') @@ -261,18 +264,18 @@ def create_bilingual_model( return nmt_model -def build_src_emb(model_opts, src_vocab): +def build_src_emb(model_opt, src_vocab): # Build embeddings. - if model_opts.model_type == "text": - src_emb = build_embeddings(model_opts, src_vocab) + if model_opt.model_type == "text": + src_emb = build_embeddings(model_opt, src_vocab) else: src_emb = None return src_emb -def build_tgt_emb(model_opts, tgt_vocab): +def build_tgt_emb(model_opt, tgt_vocab): # Build embeddings. - tgt_emb = build_embeddings(model_opts, tgt_vocab, for_encoder=False) + tgt_emb = build_embeddings(model_opt, tgt_vocab, for_encoder=False) # if share_embeddings: # tgt_emb.word_lut.weight = src_emb.word_lut.weight @@ -281,15 +284,15 @@ def build_tgt_emb(model_opts, tgt_vocab): def build_task_specific_model( - model_opts, + model_opt, vocabs_dict, device, task_queue_manager, checkpoint, ): logger.info(f'TaskQueueManager: {task_queue_manager}') - if not model_opts.model_task == ModelTask.SEQ2SEQ: - raise ValueError(f"Only ModelTask.SEQ2SEQ works - {model_opts.model_task} task") + if not model_opt.model_task == ModelTask.SEQ2SEQ: + raise ValueError(f"Only ModelTask.SEQ2SEQ works - {model_opt.model_task} task") src_embs = dict() tgt_embs = dict() @@ -298,42 +301,42 @@ def build_task_specific_model( # FIXME: it's getting late and I just want this to compile for side, lang, _, vocab in task_queue_manager.get_vocabs(side='src', vocabs_dict=vocabs_dict): - src_emb = build_src_emb(model_opts, vocab) + src_emb = build_src_emb(model_opt, vocab) src_embs[lang] = src_emb pluggable_src_emb = PluggableEmbeddings(src_embs) - encoder = build_only_enc(model_opts, pluggable_src_emb, task_queue_manager) + encoder = build_only_enc(model_opt, pluggable_src_emb, task_queue_manager) for side, lang, _, vocab in task_queue_manager.get_vocabs(side='tgt', vocabs_dict=vocabs_dict): - tgt_emb = build_tgt_emb(model_opts, vocab) + tgt_emb = build_tgt_emb(model_opt, vocab) tgt_embs[lang] = tgt_emb - generator = build_generator(model_opts, len(vocab), tgt_emb) + generator = build_generator(model_opt, len(vocab), tgt_emb) generators_md.add_module(f'generator_{lang}', generator) pluggable_tgt_emb = PluggableEmbeddings(tgt_embs) - decoder = build_only_dec(model_opts, pluggable_tgt_emb, task_queue_manager) + decoder = build_only_dec(model_opt, pluggable_tgt_emb, task_queue_manager) # TODO: implement hierarchical approach to layer sharing - attention_bridge = AttentionBridge.from_opts(model_opts) + attention_bridge = AttentionBridge.from_opt(model_opt) - if model_opts.param_init != 0.0: + if model_opt.param_init != 0.0: for p in attention_bridge.parameters(): - p.data.uniform_(-model_opts.param_init, model_opts.param_init) - if model_opts.param_init_glorot: + p.data.uniform_(-model_opt.param_init, model_opt.param_init) + if model_opt.param_init_glorot: for p in attention_bridge.parameters(): if p.dim() > 1: xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) - if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam': + if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam': attention_bridge.half() - nmt_model = mammoth.models.NMTModel( + nmt_model = onmt.models.NMTModel( encoder=encoder, decoder=decoder, attention_bridge=attention_bridge ) - if uses_adapters(model_opts): + if uses_adapters(model_opt): logger.info('Creating adapters...') - create_all_adapters(nmt_model, model_opts, task_queue_manager) + create_all_adapters(nmt_model, model_opt, task_queue_manager) print('built model:') print(nmt_model) @@ -362,54 +365,57 @@ def has_grad_hook(module, input, output) -> None: return nmt_model, generators_md -def build_only_enc(model_opts, src_emb, task_queue_manager): +def build_only_enc(model_opt, src_emb, task_queue_manager): """Truly only builds encoder: no embeddings""" - encoder = build_encoder(model_opts, src_emb, task_queue_manager) - if model_opts.param_init != 0.0: + encoder = build_encoder(model_opt, src_emb, task_queue_manager) + if model_opt.param_init != 0.0: for p in encoder.parameters(): - p.data.uniform_(-model_opts.param_init, model_opts.param_init) - if model_opts.param_init_glorot: + p.data.uniform_(-model_opt.param_init, model_opt.param_init) + if model_opt.param_init_glorot: for p in encoder.parameters(): if p.dim() > 1: xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) - if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam': + if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam': encoder.half() return encoder -def build_only_dec(model_opts, tgt_emb, task_queue_manager): - decoder = build_decoder(model_opts, tgt_emb, task_queue_manager) +def build_only_dec(model_opt, tgt_emb, task_queue_manager): + decoder = build_decoder(model_opt, tgt_emb, task_queue_manager) - if model_opts.param_init != 0.0: + if model_opt.param_init != 0.0: for p in decoder.parameters(): - p.data.uniform_(-model_opts.param_init, model_opts.param_init) - if model_opts.param_init_glorot: + p.data.uniform_(-model_opt.param_init, model_opt.param_init) + if model_opt.param_init_glorot: for p in decoder.parameters(): if p.dim() > 1: xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) - if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam': + if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam': decoder.half() return decoder -def build_generator(model_opts, n_tgts, tgt_emb): +def build_generator(model_opt, n_tgts, tgt_emb): # Build Generator. - assert not model_opts.copy_attn, 'copy_attn not supported' - gen_func = nn.LogSoftmax(dim=-1) + assert not model_opt.copy_attn, 'copy_attn not supported' + if model_opt.generator_function == "sparsemax": + gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) + else: + gen_func = nn.LogSoftmax(dim=-1) generator = nn.Sequential( - nn.Linear(model_opts.model_dim, n_tgts), Cast(torch.float32), gen_func + nn.Linear(model_opt.dec_rnn_size, n_tgts), Cast(torch.float32), gen_func ) - if model_opts.share_decoder_embeddings: + if model_opt.share_decoder_embeddings: generator[0].weight = tgt_emb.word_lut.weight - if model_opts.param_init != 0.0: + if model_opt.param_init != 0.0: for p in generator.parameters(): - p.data.uniform_(-model_opts.param_init, model_opts.param_init) - if model_opts.param_init_glorot: + p.data.uniform_(-model_opt.param_init, model_opt.param_init) + if model_opt.param_init_glorot: for p in generator.parameters(): if p.dim() > 1: xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) @@ -449,7 +455,7 @@ def build_generator(model_opts, n_tgts, tgt_emb): def build_base_model_langspec( - model_opts, + model_opt, vocabs_dict, gpu, task_queue_manager, @@ -458,10 +464,10 @@ def build_base_model_langspec( """Build a model from opts. Args: - model_opts: the option loaded from checkpoint. It's important that + model_opt: the option loaded from checkpoint. It's important that the opts have been updated and validated. See - :class:`mammoth.utils.parse.ArgumentParser`. - vocabs_dict (dict[str, mammoth.inputters.Vocab]): + :class:`onmt.utils.parse.ArgumentParser`. + vocabs_dict (dict[str, onmt.inputters.Vocab]): `Vocab` objects for the model. gpu (bool): whether to use gpu. checkpoint: the model gnerated by train phase, or a resumed snapshot @@ -474,9 +480,9 @@ def build_base_model_langspec( # for back compat when attention_dropout was not defined try: - model_opts.attention_dropout + model_opt.attention_dropout except AttributeError: - model_opts.attention_dropout = model_opts.dropout + model_opt.attention_dropout = model_opt.dropout # Build Model logger.info("MODEL BUILDER") @@ -486,7 +492,7 @@ def build_base_model_langspec( device = torch.device("cpu") logger.info(device) model, generators_md = build_task_specific_model( - model_opts=model_opts, + model_opt=model_opt, vocabs_dict=vocabs_dict, device=device, task_queue_manager=task_queue_manager, @@ -499,11 +505,11 @@ def build_base_model_langspec( return model, generators_md -def uses_adapters(opts): - return 'adapters' in opts and opts.adapters +def uses_adapters(opt): + return 'adapters' in opt and opt.adapters -def create_all_adapters(model, opts, task_queue_manager): +def create_all_adapters(model, opt, task_queue_manager): my_enc_adapter_ids = set() my_dec_adapter_ids = set() adapter_to_encoder_ids = defaultdict(set) @@ -519,7 +525,7 @@ def create_all_adapters(model, opts, task_queue_manager): adapter_to_decoder_ids[adapter_id].add(tuple(task.decoder_id)) _create_adapters( model, - opts, + opt, my_enc_adapter_ids, adapter_to_encoder_ids, my_dec_adapter_ids, @@ -527,7 +533,7 @@ def create_all_adapters(model, opts, task_queue_manager): ) -def create_bilingual_adapters(model, opts, src_lang, tgt_lang, opt_stack): +def create_bilingual_adapters(model, opt, src_lang, tgt_lang, opt_stack): my_enc_adapter_ids = [] my_dec_adapter_ids = [] adapter_to_encoder_ids = {} @@ -550,7 +556,7 @@ def create_bilingual_adapters(model, opts, src_lang, tgt_lang, opt_stack): _create_adapters( model, - opts, + opt, my_enc_adapter_ids, adapter_to_encoder_ids, my_dec_adapter_ids, @@ -560,7 +566,7 @@ def create_bilingual_adapters(model, opts, src_lang, tgt_lang, opt_stack): def _create_adapters( model, - opts, + opt, my_enc_adapter_ids, adapter_to_encoder_ids, my_dec_adapter_ids, @@ -568,14 +574,14 @@ def _create_adapters( ): my_enc_adapter_ids = [tuple(item) for item in my_enc_adapter_ids] my_dec_adapter_ids = [tuple(item) for item in my_dec_adapter_ids] - for adapter_group, adapter_opts in opts.adapters['encoder'].items(): + for adapter_group, adapter_opts in opt.adapters['encoder'].items(): layer_stack_index = adapter_opts['layer_stack_index'] for sub_id in adapter_opts['ids']: adapter_id_long = (layer_stack_index, adapter_group, sub_id) if adapter_id_long not in my_enc_adapter_ids: continue adapter = Adapter(adapter_group, sub_id) - input_dim = opts.model_dim + input_dim = opt.rnn_size hidden_dim = adapter_opts['hidden_size'] # all stacks to which this adapter should be added @@ -596,14 +602,14 @@ def _create_adapters( layer_stack_index=layer_stack_index, module_ids=adapted_stacks, ) - for adapter_group, adapter_opts in opts.adapters['decoder'].items(): + for adapter_group, adapter_opts in opt.adapters['decoder'].items(): layer_stack_index = adapter_opts['layer_stack_index'] for sub_id in adapter_opts['ids']: adapter_id_long = (layer_stack_index, adapter_group, sub_id) if adapter_id_long not in my_dec_adapter_ids: continue adapter = Adapter(adapter_group, sub_id) - input_dim = opts.model_dim + input_dim = opt.rnn_size hidden_dim = adapter_opts['hidden_size'] adapted_stacks = set( @@ -625,12 +631,12 @@ def _create_adapters( ) -def build_model(model_opts, opts, vocabs_dict, task_queue_manager, checkpoint): +def build_model(model_opt, opt, vocabs_dict, task_queue_manager, checkpoint): logger.info('Building model...') model, generators_md = build_base_model_langspec( - model_opts=model_opts, + model_opt=model_opt, vocabs_dict=vocabs_dict, - gpu=use_gpu(opts), + gpu=use_gpu(opt), task_queue_manager=task_queue_manager, checkpoint=checkpoint, ) diff --git a/onmt/models/__init__.py b/onmt/models/__init__.py new file mode 100644 index 00000000..7543dfe3 --- /dev/null +++ b/onmt/models/__init__.py @@ -0,0 +1,5 @@ +"""Module defining models.""" +from onmt.models.model_saver import build_model_saver, ModelSaver +from onmt.models.model import NMTModel, LanguageModel + +__all__ = ["build_model_saver", "ModelSaver", "NMTModel", "LanguageModel"] diff --git a/mammoth/models/adapters.py b/onmt/models/adapters.py similarity index 98% rename from mammoth/models/adapters.py rename to onmt/models/adapters.py index c2e90c5d..26311237 100644 --- a/mammoth/models/adapters.py +++ b/onmt/models/adapters.py @@ -8,9 +8,9 @@ from abc import ABC from collections import defaultdict -from mammoth.modules import TransformerEncoder -from mammoth.modules import TransformerDecoder -from mammoth.rmsnorm_torch import RMSNorm +from onmt.encoders import TransformerEncoder +from onmt.decoders import TransformerDecoder +from onmt.rmsnorm_torch import RMSNorm class AdapterLayer(ABC, nn.Module): diff --git a/mammoth/models/model.py b/onmt/models/model.py similarity index 56% rename from mammoth/models/model.py rename to onmt/models/model.py index 27bdf44a..36097e11 100644 --- a/mammoth/models/model.py +++ b/onmt/models/model.py @@ -48,8 +48,8 @@ class NMTModel(BaseModel): Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder + decoder model. Args: - encoder (mammoth.encoders.EncoderBase): an encoder object - decoder (mammoth.decoders.DecoderBase): a decoder object + encoder (onmt.encoders.EncoderBase): an encoder object + decoder (onmt.decoders.DecoderBase): a decoder object """ def __init__(self, encoder, decoder, attention_bridge): @@ -102,3 +102,68 @@ def count_parameters(self, log=print): log('decoder: {}'.format(dec)) log('* number of parameters: {}'.format(enc + dec)) return enc, dec + + +class LanguageModel(BaseModel): + """ + Core trainable object in OpenNMT. Implements a trainable interface + for a simple, generic decoder only model. + Currently TransformerLMDecoder is the only LM decoder implemented + Args: + decoder (onmt.decoders.TransformerLMDecoder): a transformer decoder + """ + + def __init__(self, encoder=None, decoder=None): + super(LanguageModel, self).__init__(encoder, decoder) + if encoder is not None: + raise ValueError("LanguageModel should not be used with an encoder") + self.decoder = decoder + + def forward(self, src, tgt, lengths, bptt=False, with_align=False): + """Forward propagate a `src` and `tgt` pair for training. + Possible initialized with a beginning decoder state. + Args: + src (Tensor): A source sequence passed to decoder. + typically for inputs this will be a padded `LongTensor` + of size ``(len, batch, features)``. However, may be an + image or other generic input depending on decoder. + tgt (LongTensor): A target sequence passed to decoder. + Size ``(tgt_len, batch, features)``. + lengths(LongTensor): The src lengths, pre-padding ``(batch,)``. + bptt (Boolean): A flag indicating if truncated bptt is set. + If reset then init_state + with_align (Boolean): A flag indicating whether output alignment, + Only valid for transformer decoder. + Returns: + (FloatTensor, dict[str, FloatTensor]): + * decoder output ``(tgt_len, batch, hidden)`` + * dictionary attention dists of ``(tgt_len, batch, src_len)`` + """ + + if not bptt: + self.decoder.init_state() + dec_out, attns = self.decoder(src, memory_bank=None, memory_lengths=lengths, with_align=with_align) + return dec_out, attns + + def update_dropout(self, dropout): + self.decoder.update_dropout(dropout) + + def count_parameters(self, log=print): + """Count number of parameters in model (& print with `log` callback). + Returns: + (int, int): + * encoder side parameter count + * decoder side parameter count + """ + + enc, dec = 0, 0 + for name, param in self.named_parameters(): + if "decoder" in name: + dec += param.nelement() + + if callable(log): + # No encoder in LM, seq2seq count formatting kept + log("total encoder parameters: {}".format(enc)) + log("total decoder parameters: {}".format(dec)) + log("* number of parameters: {}".format(enc + dec)) + return enc, dec diff --git a/mammoth/models/model_saver.py b/onmt/models/model_saver.py similarity index 91% rename from mammoth/models/model_saver.py rename to onmt/models/model_saver.py index 7ece5078..44ea45f4 100644 --- a/mammoth/models/model_saver.py +++ b/onmt/models/model_saver.py @@ -1,20 +1,20 @@ import os from collections import deque -from mammoth.utils.logging import logger +from onmt.utils.logging import logger import torch import torch.nn as nn -from mammoth.utils.module_splitter import explode_model +from onmt.utils.module_splitter import explode_model -def build_model_saver(model_opts, opts, model, vocabs_dict, optim, device_context): +def build_model_saver(model_opt, opt, model, vocabs_dict, optim, device_context): # _check_save_model_path - save_model_path = os.path.abspath(opts.save_model) + save_model_path = os.path.abspath(opt.save_model) os.makedirs(os.path.dirname(save_model_path), exist_ok=True) model_saver = ModelSaver( - opts.save_model, model, model_opts, vocabs_dict, optim, opts.keep_checkpoint, device_context, opts.save_all_gpus + opt.save_model, model, model_opt, vocabs_dict, optim, opt.keep_checkpoint, device_context, opt.save_all_gpus ) return model_saver @@ -40,7 +40,7 @@ def __init__( self, base_path, model, - model_opts, + model_opt, vocabs_dict, optim, keep_checkpoint=-1, @@ -49,7 +49,7 @@ def __init__( ): self.base_path = base_path self.model = model - self.model_opts = model_opts + self.model_opt = model_opt self.vocabs_dict = vocabs_dict self.optim = optim self.last_saved_step = None @@ -129,7 +129,7 @@ def _save(self, step, model, device_context): "model": model_state_dict, # 'generator': generator_state_dict, "vocab": self.vocabs_dict, - "opts": self.model_opts, + "opt": self.model_opt, "optim": {k: v.state_dict() for k, v in self.optim._optimizer.optimizers.items()}, "whole_model": self.model, } @@ -159,7 +159,7 @@ def _save(self, step, model, device_context): tmp_checkpoint_paths.append(checkpoint_path) if device_context.is_master(): - # TODO: not sure how to deal with model_state_dict, fields, model_opts and optim.state_dict() in a multi-gpu + # TODO: not sure how to deal with model_state_dict, fields, model_opt and optim.state_dict() in a multi-gpu # setting. Is it OK to save only from master? # model frame diff --git a/onmt/models/sru.py b/onmt/models/sru.py new file mode 100644 index 00000000..4df30ef0 --- /dev/null +++ b/onmt/models/sru.py @@ -0,0 +1,647 @@ +""" SRU Implementation """ +# flake8: noqa + +import subprocess +import platform +import os +import re +import configargparse +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_fwd, custom_bwd +from collections import namedtuple + + +# For command-line option parsing +class CheckSRU(configargparse.Action): + def __init__(self, option_strings, dest, **kwargs): + super(CheckSRU, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + if values == 'SRU': + check_sru_requirement(abort=True) + # Check pass, set the args. + setattr(namespace, self.dest, values) + + +# This SRU version implements its own cuda-level optimization, +# so it requires that: +# 1. `cupy` and `pynvrtc` python package installed. +# 2. pytorch is built with cuda support. +# 3. library path set: export LD_LIBRARY_PATH=. +def check_sru_requirement(abort=False): + """ + Return True if check pass; if check fails and abort is True, + raise an Exception, othereise return False. + """ + + # Check 1. + try: + if platform.system() == 'Windows': + subprocess.check_output('pip freeze | findstr cupy', shell=True) + subprocess.check_output('pip freeze | findstr pynvrtc', shell=True) + else: # Unix-like systems + subprocess.check_output('pip freeze | grep -w cupy', shell=True) + subprocess.check_output('pip freeze | grep -w pynvrtc', shell=True) + except subprocess.CalledProcessError: + if not abort: + return False + raise AssertionError("Using SRU requires 'cupy' and 'pynvrtc' python packages installed.") + + # Check 2. + if torch.cuda.is_available() is False: + if not abort: + return False + raise AssertionError("Using SRU requires pytorch built with cuda.") + + # Check 3. + pattern = re.compile(".*cuda/lib.*") + ld_path = os.getenv('LD_LIBRARY_PATH', "") + if re.match(pattern, ld_path) is None: + if not abort: + return False + raise AssertionError( + "Using SRU requires setting cuda lib path, e.g. export LD_LIBRARY_PATH=/usr/local/cuda/lib64." + ) + + return True + + +SRU_CODE = """ +extern "C" { + __forceinline__ __device__ float sigmoidf(float x) + { + return 1.f / (1.f + expf(-x)); + } + __forceinline__ __device__ float reluf(float x) + { + return (x > 0.f) ? x : 0.f; + } + __global__ void sru_fwd(const float * __restrict__ u, + const float * __restrict__ x, + const float * __restrict__ bias, + const float * __restrict__ init, + const float * __restrict__ mask_h, + const int len, const int batch, + const int d, const int k, + float * __restrict__ h, + float * __restrict__ c, + const int activation_type) + { + assert ((k == 3) || (x == NULL)); + int ncols = batch*d; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + int ncols_u = ncols*k; + int ncols_x = (k == 3) ? ncols : ncols_u; + const float bias1 = *(bias + (col%d)); + const float bias2 = *(bias + (col%d) + d); + const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); + float cur = *(init + col); + const float *up = u + (col*k); + const float *xp = (k == 3) ? (x + col) : (up + 3); + float *cp = c + col; + float *hp = h + col; + for (int row = 0; row < len; ++row) + { + float g1 = sigmoidf((*(up+1))+bias1); + float g2 = sigmoidf((*(up+2))+bias2); + cur = (cur-(*up))*g1 + (*up); + *cp = cur; + float val = (activation_type == 1) ? tanh(cur) : ( + (activation_type == 2) ? reluf(cur) : cur + ); + *hp = (val*mask-(*xp))*g2 + (*xp); + up += ncols_u; + xp += ncols_x; + cp += ncols; + hp += ncols; + } + } + __global__ void sru_bwd(const float * __restrict__ u, + const float * __restrict__ x, + const float * __restrict__ bias, + const float * __restrict__ init, + const float * __restrict__ mask_h, + const float * __restrict__ c, + const float * __restrict__ grad_h, + const float * __restrict__ grad_last, + const int len, + const int batch, const int d, const int k, + float * __restrict__ grad_u, + float * __restrict__ grad_x, + float * __restrict__ grad_bias, + float * __restrict__ grad_init, + int activation_type) + { + assert((k == 3) || (x == NULL)); + assert((k == 3) || (grad_x == NULL)); + int ncols = batch*d; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + int ncols_u = ncols*k; + int ncols_x = (k == 3) ? ncols : ncols_u; + const float bias1 = *(bias + (col%d)); + const float bias2 = *(bias + (col%d) + d); + const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); + float gbias1 = 0; + float gbias2 = 0; + float cur = *(grad_last + col); + const float *up = u + (col*k) + (len-1)*ncols_u; + const float *xp = (k == 3) ? (x + col + (len-1)*ncols) : (up + 3); + const float *cp = c + col + (len-1)*ncols; + const float *ghp = grad_h + col + (len-1)*ncols; + float *gup = grad_u + (col*k) + (len-1)*ncols_u; + float *gxp = (k == 3) ? (grad_x + col + (len-1)*ncols) : (gup + 3); + for (int row = len-1; row >= 0; --row) + { + const float g1 = sigmoidf((*(up+1))+bias1); + const float g2 = sigmoidf((*(up+2))+bias2); + const float c_val = (activation_type == 1) ? tanh(*cp) : ( + (activation_type == 2) ? reluf(*cp) : (*cp) + ); + const float x_val = *xp; + const float u_val = *up; + const float prev_c_val = (row>0) ? (*(cp-ncols)) : (*(init+col)); + const float gh_val = *ghp; + // h = c*g2 + x*(1-g2) = (c-x)*g2 + x + // c = c'*g1 + g0*(1-g1) = (c'-g0)*g1 + g0 + // grad wrt x + *gxp = gh_val*(1-g2); + // grad wrt g2, u2 and bias2 + float gg2 = gh_val*(c_val*mask-x_val)*(g2*(1-g2)); + *(gup+2) = gg2; + gbias2 += gg2; + // grad wrt c + const float tmp = (activation_type == 1) ? (g2*(1-c_val*c_val)) : ( + ((activation_type == 0) || (c_val > 0)) ? g2 : 0.f + ); + const float gc = gh_val*mask*tmp + cur; + // grad wrt u0 + *gup = gc*(1-g1); + // grad wrt g1, u1, and bias1 + float gg1 = gc*(prev_c_val-u_val)*(g1*(1-g1)); + *(gup+1) = gg1; + gbias1 += gg1; + // grad wrt c' + cur = gc*g1; + up -= ncols_u; + xp -= ncols_x; + cp -= ncols; + gup -= ncols_u; + gxp -= ncols_x; + ghp -= ncols; + } + *(grad_bias + col) = gbias1; + *(grad_bias + col + ncols) = gbias2; + *(grad_init +col) = cur; + } + __global__ void sru_bi_fwd(const float * __restrict__ u, + const float * __restrict__ x, + const float * __restrict__ bias, + const float * __restrict__ init, + const float * __restrict__ mask_h, + const int len, const int batch, + const int d, const int k, + float * __restrict__ h, + float * __restrict__ c, + const int activation_type) + { + assert ((k == 3) || (x == NULL)); + assert ((k == 3) || (k == 4)); + int ncols = batch*d*2; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + int ncols_u = ncols*k; + int ncols_x = (k == 3) ? ncols : ncols_u; + const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); + float cur = *(init + col); + const int d2 = d*2; + const bool flip = (col%d2) >= d; + const float bias1 = *(bias + (col%d2)); + const float bias2 = *(bias + (col%d2) + d2); + const float *up = u + (col*k); + const float *xp = (k == 3) ? (x + col) : (up + 3); + float *cp = c + col; + float *hp = h + col; + if (flip) { + up += (len-1)*ncols_u; + xp += (len-1)*ncols_x; + cp += (len-1)*ncols; + hp += (len-1)*ncols; + } + int ncols_u_ = flip ? -ncols_u : ncols_u; + int ncols_x_ = flip ? -ncols_x : ncols_x; + int ncols_ = flip ? -ncols : ncols; + for (int cnt = 0; cnt < len; ++cnt) + { + float g1 = sigmoidf((*(up+1))+bias1); + float g2 = sigmoidf((*(up+2))+bias2); + cur = (cur-(*up))*g1 + (*up); + *cp = cur; + float val = (activation_type == 1) ? tanh(cur) : ( + (activation_type == 2) ? reluf(cur) : cur + ); + *hp = (val*mask-(*xp))*g2 + (*xp); + up += ncols_u_; + xp += ncols_x_; + cp += ncols_; + hp += ncols_; + } + } + __global__ void sru_bi_bwd(const float * __restrict__ u, + const float * __restrict__ x, + const float * __restrict__ bias, + const float * __restrict__ init, + const float * __restrict__ mask_h, + const float * __restrict__ c, + const float * __restrict__ grad_h, + const float * __restrict__ grad_last, + const int len, const int batch, + const int d, const int k, + float * __restrict__ grad_u, + float * __restrict__ grad_x, + float * __restrict__ grad_bias, + float * __restrict__ grad_init, + int activation_type) + { + assert((k == 3) || (x == NULL)); + assert((k == 3) || (grad_x == NULL)); + assert((k == 3) || (k == 4)); + int ncols = batch*d*2; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= ncols) return; + int ncols_u = ncols*k; + int ncols_x = (k == 3) ? ncols : ncols_u; + const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); + float gbias1 = 0; + float gbias2 = 0; + float cur = *(grad_last + col); + const int d2 = d*2; + const bool flip = ((col%d2) >= d); + const float bias1 = *(bias + (col%d2)); + const float bias2 = *(bias + (col%d2) + d2); + const float *up = u + (col*k); + const float *xp = (k == 3) ? (x + col) : (up + 3); + const float *cp = c + col; + const float *ghp = grad_h + col; + float *gup = grad_u + (col*k); + float *gxp = (k == 3) ? (grad_x + col) : (gup + 3); + if (!flip) { + up += (len-1)*ncols_u; + xp += (len-1)*ncols_x; + cp += (len-1)*ncols; + ghp += (len-1)*ncols; + gup += (len-1)*ncols_u; + gxp += (len-1)*ncols_x; + } + int ncols_u_ = flip ? -ncols_u : ncols_u; + int ncols_x_ = flip ? -ncols_x : ncols_x; + int ncols_ = flip ? -ncols : ncols; + for (int cnt = 0; cnt < len; ++cnt) + { + const float g1 = sigmoidf((*(up+1))+bias1); + const float g2 = sigmoidf((*(up+2))+bias2); + const float c_val = (activation_type == 1) ? tanh(*cp) : ( + (activation_type == 2) ? reluf(*cp) : (*cp) + ); + const float x_val = *xp; + const float u_val = *up; + const float prev_c_val = (cnt 0)) ? g2 : 0.f + ); + const float gc = gh_val*mask*tmp + cur; + // grad wrt u0 + *gup = gc*(1-g1); + // grad wrt g1, u1, and bias1 + float gg1 = gc*(prev_c_val-u_val)*(g1*(1-g1)); + *(gup+1) = gg1; + gbias1 += gg1; + // grad wrt c' + cur = gc*g1; + up -= ncols_u_; + xp -= ncols_x_; + cp -= ncols_; + gup -= ncols_u_; + gxp -= ncols_x_; + ghp -= ncols_; + } + *(grad_bias + col) = gbias1; + *(grad_bias + col + ncols) = gbias2; + *(grad_init +col) = cur; + } +} +""" +SRU_FWD_FUNC, SRU_BWD_FUNC = None, None +SRU_BiFWD_FUNC, SRU_BiBWD_FUNC = None, None +SRU_STREAM = None + + +def load_sru_mod(): + global SRU_FWD_FUNC, SRU_BWD_FUNC, SRU_BiFWD_FUNC, SRU_BiBWD_FUNC + global SRU_STREAM + if check_sru_requirement(): + from cupy.cuda import function + from pynvrtc.compiler import Program + + # This sets up device to use. + device = torch.device("cuda") + tmp_ = torch.rand(1, 1).to(device) + + sru_prog = Program(SRU_CODE.encode('utf-8'), 'sru_prog.cu'.encode('utf-8')) + sru_ptx = sru_prog.compile() + sru_mod = function.Module() + sru_mod.load(bytes(sru_ptx.encode())) + + SRU_FWD_FUNC = sru_mod.get_function('sru_fwd') + SRU_BWD_FUNC = sru_mod.get_function('sru_bwd') + SRU_BiFWD_FUNC = sru_mod.get_function('sru_bi_fwd') + SRU_BiBWD_FUNC = sru_mod.get_function('sru_bi_bwd') + + stream = namedtuple('Stream', ['ptr']) + SRU_STREAM = stream(ptr=torch.cuda.current_stream().cuda_stream) + + +class SRU_Compute(Function): + def __init__(self, activation_type, d_out, bidirectional=False): + SRU_Compute.maybe_load_sru_mod() + super(SRU_Compute, self).__init__() + self.activation_type = activation_type + self.d_out = d_out + self.bidirectional = bidirectional + + @staticmethod + def maybe_load_sru_mod(): + global SRU_FWD_FUNC + + if SRU_FWD_FUNC is None: + load_sru_mod() + + @custom_fwd + def forward(self, u, x, bias, init=None, mask_h=None): + bidir = 2 if self.bidirectional else 1 + length = x.size(0) if x.dim() == 3 else 1 + batch = x.size(-2) + d = self.d_out + k = u.size(-1) // d + k_ = k // 2 if self.bidirectional else k + ncols = batch * d * bidir + thread_per_block = min(512, ncols) + num_block = (ncols - 1) // thread_per_block + 1 + + init_ = x.new(ncols).zero_() if init is None else init + size = (length, batch, d * bidir) if x.dim() == 3 else (batch, d * bidir) + c = x.new(*size) + h = x.new(*size) + + FUNC = SRU_FWD_FUNC if not self.bidirectional else SRU_BiFWD_FUNC + FUNC( + args=[ + u.contiguous().data_ptr(), + x.contiguous().data_ptr() if k_ == 3 else 0, + bias.data_ptr(), + init_.contiguous().data_ptr(), + mask_h.data_ptr() if mask_h is not None else 0, + length, + batch, + d, + k_, + h.data_ptr(), + c.data_ptr(), + self.activation_type, + ], + block=(thread_per_block, 1, 1), + grid=(num_block, 1, 1), + stream=SRU_STREAM, + ) + + self.save_for_backward(u, x, bias, init, mask_h) + self.intermediate = c + if x.dim() == 2: + last_hidden = c + elif self.bidirectional: + # -> directions x batch x dim + last_hidden = torch.stack((c[-1, :, :d], c[0, :, d:])) + else: + last_hidden = c[-1] + return h, last_hidden + + @custom_bwd + def backward(self, grad_h, grad_last): + if self.bidirectional: + grad_last = torch.cat((grad_last[0], grad_last[1]), 1) + bidir = 2 if self.bidirectional else 1 + u, x, bias, init, mask_h = self.saved_tensors + c = self.intermediate + length = x.size(0) if x.dim() == 3 else 1 + batch = x.size(-2) + d = self.d_out + k = u.size(-1) // d + k_ = k // 2 if self.bidirectional else k + ncols = batch * d * bidir + thread_per_block = min(512, ncols) + num_block = (ncols - 1) // thread_per_block + 1 + + init_ = x.new(ncols).zero_() if init is None else init + grad_u = u.new(*u.size()) + grad_bias = x.new(2, batch, d * bidir) + grad_init = x.new(batch, d * bidir) + + # For DEBUG + # size = (length, batch, x.size(-1)) \ + # if x.dim() == 3 else (batch, x.size(-1)) + # grad_x = x.new(*x.size()) if k_ == 3 else x.new(*size).zero_() + + # Normal use + grad_x = x.new(*x.size()) if k_ == 3 else None + + FUNC = SRU_BWD_FUNC if not self.bidirectional else SRU_BiBWD_FUNC + FUNC( + args=[ + u.contiguous().data_ptr(), + x.contiguous().data_ptr() if k_ == 3 else 0, + bias.data_ptr(), + init_.contiguous().data_ptr(), + mask_h.data_ptr() if mask_h is not None else 0, + c.data_ptr(), + grad_h.contiguous().data_ptr(), + grad_last.contiguous().data_ptr(), + length, + batch, + d, + k_, + grad_u.data_ptr(), + grad_x.data_ptr() if k_ == 3 else 0, + grad_bias.data_ptr(), + grad_init.data_ptr(), + self.activation_type, + ], + block=(thread_per_block, 1, 1), + grid=(num_block, 1, 1), + stream=SRU_STREAM, + ) + return grad_u, grad_x, grad_bias.sum(1).view(-1), grad_init, None + + +class SRUCell(nn.Module): + def __init__(self, n_in, n_out, dropout=0, rnn_dropout=0, bidirectional=False, use_tanh=1, use_relu=0): + super(SRUCell, self).__init__() + self.n_in = n_in + self.n_out = n_out + self.rnn_dropout = rnn_dropout + self.dropout = dropout + self.bidirectional = bidirectional + self.activation_type = 2 if use_relu else (1 if use_tanh else 0) + + out_size = n_out * 2 if bidirectional else n_out + k = 4 if n_in != out_size else 3 + self.size_per_dir = n_out * k + self.weight = nn.Parameter(torch.Tensor(n_in, self.size_per_dir * 2 if bidirectional else self.size_per_dir)) + self.bias = nn.Parameter(torch.Tensor(n_out * 4 if bidirectional else n_out * 2)) + self.init_weight() + + def init_weight(self): + val_range = (3.0 / self.n_in) ** 0.5 + self.weight.data.uniform_(-val_range, val_range) + self.bias.data.zero_() + + def set_bias(self, bias_val=0): + n_out = self.n_out + if self.bidirectional: + self.bias.data[n_out * 2 :].zero_().add_(bias_val) + else: + self.bias.data[n_out:].zero_().add_(bias_val) + + def forward(self, input, c0=None): + assert input.dim() == 2 or input.dim() == 3 + n_in, n_out = self.n_in, self.n_out + batch = input.size(-2) + if c0 is None: + c0 = input.data.new(batch, n_out if not self.bidirectional else n_out * 2).zero_() + + if self.training and (self.rnn_dropout > 0): + mask = self.get_dropout_mask_((batch, n_in), self.rnn_dropout) + x = input * mask.expand_as(input) + else: + x = input + + x_2d = x if x.dim() == 2 else x.contiguous().view(-1, n_in) + u = x_2d.mm(self.weight) + + if self.training and (self.dropout > 0): + bidir = 2 if self.bidirectional else 1 + mask_h = self.get_dropout_mask_((batch, n_out * bidir), self.dropout) + h, c = SRU_Compute(self.activation_type, n_out, self.bidirectional)(u, input, self.bias, c0, mask_h) + else: + h, c = SRU_Compute(self.activation_type, n_out, self.bidirectional)(u, input, self.bias, c0) + + return h, c + + def get_dropout_mask_(self, size, p): + w = self.weight.data + return w.new(*size).bernoulli_(1 - p).div_(1 - p) + + +class SRU(nn.Module): + """ + Implementation of "Training RNNs as Fast as CNNs" + :cite:`DBLP:journals/corr/abs-1709-02755` + + TODO: turn to pytorch's implementation when it is available. + + This implementation is adpoted from the author of the paper: + https://github.com/taolei87/sru/blob/master/cuda_functional.py. + + Args: + input_size (int): input to model + hidden_size (int): hidden dimension + num_layers (int): number of layers + dropout (float): dropout to use (stacked) + rnn_dropout (float): dropout to use (recurrent) + bidirectional (bool): bidirectional + use_tanh (bool): activation + use_relu (bool): activation + """ + + def __init__( + self, + input_size, + hidden_size, + num_layers=2, + dropout=0, + rnn_dropout=0, + bidirectional=False, + use_tanh=1, + use_relu=0, + ): + # An entry check here, will catch on train side and translate side + # if requirements are not satisfied. + check_sru_requirement(abort=True) + super(SRU, self).__init__() + self.n_in = input_size + self.n_out = hidden_size + self.depth = num_layers + self.dropout = dropout + self.rnn_dropout = rnn_dropout + self.rnn_lst = nn.ModuleList() + self.bidirectional = bidirectional + self.out_size = hidden_size * 2 if bidirectional else hidden_size + + for i in range(num_layers): + sru_cell = SRUCell( + n_in=self.n_in if i == 0 else self.out_size, + n_out=self.n_out, + dropout=dropout if i + 1 != num_layers else 0, + rnn_dropout=rnn_dropout, + bidirectional=bidirectional, + use_tanh=use_tanh, + use_relu=use_relu, + ) + self.rnn_lst.append(sru_cell) + + def set_bias(self, bias_val=0): + for l in self.rnn_lst: + l.set_bias(bias_val) + + def forward(self, input, c0=None, return_hidden=True): + assert input.dim() == 3 # (len, batch, n_in) + dir_ = 2 if self.bidirectional else 1 + if c0 is None: + zeros = input.data.new(input.size(1), self.n_out * dir_).zero_() + c0 = [zeros for i in range(self.depth)] + else: + if isinstance(c0, tuple): + # RNNDecoderState wraps hidden as a tuple. + c0 = c0[0] + assert c0.dim() == 3 # (depth, batch, dir_*n_out) + c0 = [h.squeeze(0) for h in c0.chunk(self.depth, 0)] + + prevx = input + lstc = [] + for i, rnn in enumerate(self.rnn_lst): + h, c = rnn(prevx, c0[i]) + prevx = h + lstc.append(c) + + if self.bidirectional: + # fh -> (layers*directions) x batch x dim + fh = torch.cat(lstc) + else: + fh = torch.stack(lstc) + + if return_hidden: + return prevx, fh + else: + return prevx diff --git a/onmt/models/stacked_rnn.py b/onmt/models/stacked_rnn.py new file mode 100644 index 00000000..cb201f04 --- /dev/null +++ b/onmt/models/stacked_rnn.py @@ -0,0 +1,65 @@ +""" Implementation of ONMT RNN for Input Feeding Decoding """ +import torch +import torch.nn as nn + + +class StackedLSTM(nn.Module): + """ + Our own implementation of stacked LSTM. + Needed for the decoder, because we do input feeding. + """ + + def __init__(self, num_layers, input_size, rnn_size, dropout): + super(StackedLSTM, self).__init__() + self.dropout = nn.Dropout(dropout) + self.num_layers = num_layers + self.layers = nn.ModuleList() + + for _ in range(num_layers): + self.layers.append(nn.LSTMCell(input_size, rnn_size)) + input_size = rnn_size + + def forward(self, input_feed, hidden): + h_0, c_0 = hidden + h_1, c_1 = [], [] + for i, layer in enumerate(self.layers): + h_1_i, c_1_i = layer(input_feed, (h_0[i], c_0[i])) + input_feed = h_1_i + if i + 1 != self.num_layers: + input_feed = self.dropout(input_feed) + h_1 += [h_1_i] + c_1 += [c_1_i] + + h_1 = torch.stack(h_1) + c_1 = torch.stack(c_1) + + return input_feed, (h_1, c_1) + + +class StackedGRU(nn.Module): + """ + Our own implementation of stacked GRU. + Needed for the decoder, because we do input feeding. + """ + + def __init__(self, num_layers, input_size, rnn_size, dropout): + super(StackedGRU, self).__init__() + self.dropout = nn.Dropout(dropout) + self.num_layers = num_layers + self.layers = nn.ModuleList() + + for _ in range(num_layers): + self.layers.append(nn.GRUCell(input_size, rnn_size)) + input_size = rnn_size + + def forward(self, input_feed, hidden): + h_1 = [] + for i, layer in enumerate(self.layers): + h_1_i = layer(input_feed, hidden[0][i]) + input_feed = h_1_i + if i + 1 != self.num_layers: + input_feed = self.dropout(input_feed) + h_1 += [h_1_i] + + h_1 = torch.stack(h_1) + return input_feed, (h_1,) diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py new file mode 100644 index 00000000..45122e53 --- /dev/null +++ b/onmt/modules/__init__.py @@ -0,0 +1,34 @@ +""" Attention and normalization modules """ +from onmt.modules.util_class import Elementwise +from onmt.modules.gate import context_gate_factory, ContextGate +from onmt.modules.global_attention import GlobalAttention +from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention +from onmt.modules.copy_generator import ( + CopyGenerator, + CopyGeneratorLoss, + CopyGeneratorLossCompute, + CopyGeneratorLMLossCompute, +) +from onmt.modules.multi_headed_attn import MultiHeadedAttention +from onmt.modules.embeddings import Embeddings, PositionalEncoding +from onmt.modules.weight_norm import WeightNormConv2d +from onmt.modules.average_attn import AverageAttention +from onmt.modules.stable_embeddings import StableEmbedding + +__all__ = [ + "Elementwise", + "context_gate_factory", + "ContextGate", + "GlobalAttention", + "ConvMultiStepAttention", + "CopyGenerator", + "CopyGeneratorLoss", + "CopyGeneratorLossCompute", + "MultiHeadedAttention", + "Embeddings", + "PositionalEncoding", + "WeightNormConv2d", + "AverageAttention", + "CopyGeneratorLMLossCompute", + "StableEmbedding", +] diff --git a/mammoth/modules/average_attn.py b/onmt/modules/average_attn.py similarity index 97% rename from mammoth/modules/average_attn.py rename to onmt/modules/average_attn.py index 24c8eb23..04fec83f 100644 --- a/mammoth/modules/average_attn.py +++ b/onmt/modules/average_attn.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn -from mammoth.modules.position_ffn import PositionwiseFeedForward -from mammoth.modules.position_ffn import ActivationFunction +from onmt.modules.position_ffn import PositionwiseFeedForward +from onmt.modules.position_ffn import ActivationFunction class AverageAttention(nn.Module): diff --git a/onmt/modules/conv_multi_step_attention.py b/onmt/modules/conv_multi_step_attention.py new file mode 100644 index 00000000..fe1fd4b4 --- /dev/null +++ b/onmt/modules/conv_multi_step_attention.py @@ -0,0 +1,76 @@ +""" Multi Step Attention for CNN """ +import torch +import torch.nn as nn +import torch.nn.functional as F +from onmt.utils.misc import aeq + + +SCALE_WEIGHT = 0.5**0.5 + + +def seq_linear(linear, x): + """linear transform for 3-d tensor""" + batch, hidden_size, length, _ = x.size() + h = linear(torch.transpose(x, 1, 2).contiguous().view(batch * length, hidden_size)) + return torch.transpose(h.view(batch, length, hidden_size, 1), 1, 2) + + +class ConvMultiStepAttention(nn.Module): + """ + Conv attention takes a key matrix, a value matrix and a query vector. + Attention weight is calculated by key matrix with the query vector + and sum on the value matrix. And the same operation is applied + in each decode conv layer. + """ + + def __init__(self, input_size): + super(ConvMultiStepAttention, self).__init__() + self.linear_in = nn.Linear(input_size, input_size) + self.mask = None + + def apply_mask(self, mask): + """Apply mask""" + self.mask = mask + + def forward(self, base_target_emb, input_from_dec, encoder_out_top, encoder_out_combine): + """ + Args: + base_target_emb: target emb tensor + input_from_dec: output of decode conv + encoder_out_top: the key matrix for calculation of attetion weight, + which is the top output of encode conv + encoder_out_combine: + the value matrix for the attention-weighted sum, + which is the combination of base emb and top output of encode + """ + + # checks + # batch, channel, height, width = base_target_emb.size() + batch, _, height, _ = base_target_emb.size() + # batch_, channel_, height_, width_ = input_from_dec.size() + batch_, _, height_, _ = input_from_dec.size() + aeq(batch, batch_) + aeq(height, height_) + + # enc_batch, enc_channel, enc_height = encoder_out_top.size() + enc_batch, _, enc_height = encoder_out_top.size() + # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size() + enc_batch_, _, enc_height_ = encoder_out_combine.size() + + aeq(enc_batch, enc_batch_) + aeq(enc_height, enc_height_) + + preatt = seq_linear(self.linear_in, input_from_dec) + target = (base_target_emb + preatt) * SCALE_WEIGHT + target = torch.squeeze(target, 3) + target = torch.transpose(target, 1, 2) + pre_attn = torch.bmm(target, encoder_out_top) + + if self.mask is not None: + pre_attn.data.masked_fill_(self.mask, -float('inf')) + + attn = F.softmax(pre_attn, dim=2) + + context_output = torch.bmm(attn, torch.transpose(encoder_out_combine, 1, 2)) + context_output = torch.transpose(torch.unsqueeze(context_output, 3), 1, 2) + return context_output, attn diff --git a/onmt/modules/copy_generator.py b/onmt/modules/copy_generator.py new file mode 100644 index 00000000..3e426119 --- /dev/null +++ b/onmt/modules/copy_generator.py @@ -0,0 +1,264 @@ +import torch +import torch.nn as nn + +from onmt.utils.misc import aeq +from onmt.utils.loss import CommonLossCompute + + +def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None, batch_dim=1, batch_offset=None): + """ + Given scores from an expanded dictionary + corresponeding to a batch, sums together copies, + with a dictionary word when it is ambiguous. + """ + offset = len(tgt_vocab) + for b in range(scores.size(batch_dim)): + blank = [] + fill = [] + + if src_vocabs is None: + src_vocab = batch.src_ex_vocab[b] + else: + batch_id = batch_offset[b] if batch_offset is not None else b + index = batch.indices.data[batch_id] + src_vocab = src_vocabs[index] + + for i in range(1, len(src_vocab)): + sw = src_vocab.itos[i] + ti = tgt_vocab.stoi[sw] + if ti != 0: + blank.append(offset + i) + fill.append(ti) + if blank: + blank = torch.Tensor(blank).type_as(batch.indices.data) + fill = torch.Tensor(fill).type_as(batch.indices.data) + score = scores[:, b] if batch_dim == 1 else scores[b] + score.index_add_(1, fill, score.index_select(1, blank)) + score.index_fill_(1, blank, 1e-10) + return scores + + +class CopyGenerator(nn.Module): + """An implementation of pointer-generator networks + :cite:`DBLP:journals/corr/SeeLM17`. + + These networks consider copying words + directly from the source sequence. + + The copy generator is an extended version of the standard + generator that computes three values. + + * :math:`p_{softmax}` the standard softmax over `tgt_dict` + * :math:`p(z)` the probability of copying a word from + the source + * :math:`p_{copy}` the probility of copying a particular word. + taken from the attention distribution directly. + + The model returns a distribution over the extend dictionary, + computed as + + :math:`p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)` + + + .. mermaid:: + + graph BT + A[input] + S[src_map] + B[softmax] + BB[switch] + C[attn] + D[copy] + O[output] + A --> B + A --> BB + S --> D + C --> D + D --> O + B --> O + BB --> O + + + Args: + input_size (int): size of input representation + output_size (int): size of output vocabulary + pad_idx (int) + """ + + def __init__(self, input_size, output_size, pad_idx): + super(CopyGenerator, self).__init__() + self.linear = nn.Linear(input_size, output_size) + self.linear_copy = nn.Linear(input_size, 1) + self.pad_idx = pad_idx + + def forward(self, hidden, attn, src_map): + """ + Compute a distribution over the target dictionary + extended by the dynamic dictionary implied by copying + source words. + + Args: + hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)`` + attn (FloatTensor): attn for each ``(batch x tlen, slen)`` + src_map (FloatTensor): + A sparse indicator matrix mapping each source word to + its index in the "extended" vocab containing. + ``(src_len, batch, extra_words)`` + """ + + # CHECKS + batch_by_tlen, _ = hidden.size() + batch_by_tlen_, slen = attn.size() + slen_, batch, cvocab = src_map.size() + aeq(batch_by_tlen, batch_by_tlen_) + aeq(slen, slen_) + + # Original probabilities. + logits = self.linear(hidden) + logits[:, self.pad_idx] = -float('inf') + prob = torch.softmax(logits, 1) + + # Probability of copying p(z=1) batch. + p_copy = torch.sigmoid(self.linear_copy(hidden)) + # Probability of not copying: p_{word}(w) * (1 - p(z)) + out_prob = torch.mul(prob, 1 - p_copy) + mul_attn = torch.mul(attn, p_copy) + copy_prob = torch.bmm(mul_attn.view(-1, batch, slen).transpose(0, 1), src_map.transpose(0, 1)).transpose(0, 1) + copy_prob = copy_prob.contiguous().view(-1, cvocab) + return torch.cat([out_prob, copy_prob], 1) + + +class CopyGeneratorLoss(nn.Module): + """Copy generator criterion.""" + + def __init__(self, vocab_size, force_copy, unk_index=0, ignore_index=-100, eps=1e-20): + super(CopyGeneratorLoss, self).__init__() + self.force_copy = force_copy + self.eps = eps + self.vocab_size = vocab_size + self.ignore_index = ignore_index + self.unk_index = unk_index + + def forward(self, scores, align, target): + """ + Args: + scores (FloatTensor): ``(batch_size*tgt_len)`` x dynamic vocab size + whose sum along dim 1 is less than or equal to 1, i.e. cols + softmaxed. + align (LongTensor): ``(batch_size x tgt_len)`` + target (LongTensor): ``(batch_size x tgt_len)`` + """ + # probabilities assigned by the model to the gold targets + vocab_probs = scores.gather(1, target.unsqueeze(1)).squeeze(1) + + # probability of tokens copied from source + copy_ix = align.unsqueeze(1) + self.vocab_size + copy_tok_probs = scores.gather(1, copy_ix).squeeze(1) + # Set scores for unk to 0 and add eps + copy_tok_probs[align == self.unk_index] = 0 + copy_tok_probs += self.eps # to avoid -inf logs + + # find the indices in which you do not use the copy mechanism + non_copy = align == self.unk_index + if not self.force_copy: + non_copy = non_copy | (target != self.unk_index) + + probs = torch.where(non_copy, copy_tok_probs + vocab_probs, copy_tok_probs) + + loss = -probs.log() # just NLLLoss; can the module be incorporated? + # Drop padding. + loss[target == self.ignore_index] = 0 + return loss + + +class CommonCopyGeneratorLossCompute(CommonLossCompute): + """Common Copy Generator Loss Computation.""" + + def __init__(self, criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0, tgt_shift_index=1): + super(CommonCopyGeneratorLossCompute, self).__init__( + criterion, generator, lambda_coverage=lambda_coverage, tgt_shift_index=tgt_shift_index + ) + self.tgt_vocab = tgt_vocab + self.normalize_by_length = normalize_by_length + + def _compute_loss(self, batch, output, target, copy_attn, align, std_attn=None, coverage_attn=None): + """Compute the loss. + + The args must match :func:`self._make_shard_state()`. + + Args: + batch: the current batch. + output: the predict output from the model. + target: the validate target to compare output with. + copy_attn: the copy attention value. + align: the align info. + """ + target = target.view(-1) + align = align.view(-1) + scores = self.generator(self._bottle(output), self._bottle(copy_attn), batch.src_map) + loss = self.criterion(scores, align, target) + + if self.lambda_coverage != 0.0: + coverage_loss = self._compute_coverage_loss(std_attn, coverage_attn) + loss += coverage_loss + + # this block does not depend on the loss value computed above + # and is used only for stats + scores_data = collapse_copy_scores( + self._unbottle(scores.clone(), batch.batch_size), batch, self.tgt_vocab, None + ) + scores_data = self._bottle(scores_data) + + # this block does not depend on the loss value computed above + # and is used only for stats + # Correct target copy token instead of + # tgt[i] = align[i] + len(tgt_vocab) + # for i such that tgt[i] == 0 and align[i] != 0 + target_data = target.clone() + unk = self.criterion.unk_index + correct_mask = (target_data == unk) & (align != unk) + offset_align = align[correct_mask] + len(self.tgt_vocab) + target_data[correct_mask] += offset_align + + # Compute sum of perplexities for stats + stats = self._stats(loss.sum().clone(), scores_data, target_data) + + # this part looks like it belongs in CopyGeneratorLoss + if self.normalize_by_length: + # Compute Loss as NLL divided by seq length + tgt_lens = batch.tgt[:, :, 0].ne(self.padding_idx).sum(0).float() + # Compute Total Loss per sequence in batch + loss = loss.view(-1, batch.batch_size).sum(0) + # Divide by length of each sequence and sum + loss = torch.div(loss, tgt_lens).sum() + else: + loss = loss.sum() + + return loss, stats + + def _make_shard_state(self, batch, output, range_, attns): + """See base class for args description.""" + shard_state = super(CommonCopyGeneratorLossCompute, self)._make_shard_state(batch, output, range_, attns) + + start_range = range_[0] + self.tgt_shift_index + end_range = range_[1] + shard_state.update({"copy_attn": attns.get("copy"), "align": batch.alignment[start_range:end_range]}) + return shard_state + + +class CopyGeneratorLossCompute(CommonCopyGeneratorLossCompute): + """Copy Generator Loss Computation.""" + + def __init__(self, criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0): + super(CopyGeneratorLossCompute, self).__init__( + criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0, tgt_shift_index=1 + ) + + +class CopyGeneratorLMLossCompute(CommonCopyGeneratorLossCompute): + """Copy Generator LM Loss Computation.""" + + def __init__(self, criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0): + super(CopyGeneratorLMLossCompute, self).__init__( + criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0, tgt_shift_index=0 + ) diff --git a/mammoth/modules/embeddings.py b/onmt/modules/embeddings.py similarity index 89% rename from mammoth/modules/embeddings.py rename to onmt/modules/embeddings.py index dbd16ab1..ce8f7732 100644 --- a/mammoth/modules/embeddings.py +++ b/onmt/modules/embeddings.py @@ -5,10 +5,11 @@ import torch import torch.nn as nn -from mammoth.modules.util_class import Elementwise -# from mammoth.utils.logging import logger +from onmt.modules.util_class import Elementwise +# from onmt.utils.logging import logger # import bitsandbytes as bnb +# from onmt.modules.stable_embeddings import StableEmbedding class SequenceTooLongError(Exception): @@ -65,7 +66,7 @@ def forward(self, emb, step=None): class Embeddings(nn.Module): """Words embeddings for encoder/decoder. - Additionally includes ability to add input features + Additionally includes ability to add sparse input features based on "Linguistic Input Features Improve Neural Machine Translation" :cite:`sennrich2016linguistic`. @@ -91,7 +92,7 @@ class Embeddings(nn.Module): word_vocab_size (int): size of dictionary of embeddings for words. feat_vocab_sizes (List[int], optional): list of size of dictionary of embeddings for each feature. - position_encoding (bool): see :class:`~mammoth.modules.PositionalEncoding` + position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding` feat_merge (string): merge action for the features embeddings: concat, sum or mlp. feat_vec_exponent (float): when using `-feat_merge concat`, feature @@ -115,6 +116,7 @@ def __init__( feat_padding_idx=[], feat_vocab_sizes=[], dropout=0, + sparse=False, freeze_word_vecs=False, ): self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent, feat_vec_size, feat_padding_idx) @@ -145,7 +147,7 @@ def __init__( # The embedding matrix look-up tables. The first look-up table # is for words. Subsequent ones are for features, if any exist. emb_params = zip(vocab_sizes, emb_dims, pad_indices) - embeddings = [nn.Embedding(vocab, dim, padding_idx=pad) for vocab, dim, pad in emb_params] + embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse) for vocab, dim, pad in emb_params] emb_luts = Elementwise(feat_merge, embeddings) # The final output size of word + feature vectors. This can vary @@ -338,12 +340,12 @@ def convert_to_torch_tensor(word_to_float_list_dict, vocab): return tensor # FIXME: seems it got nuked during the great refactoring of data -# def prepare_pretrained_embeddings(opts, fields): -# if all([opts.both_embeddings is None, opts.src_embeddings is None, opts.tgt_embeddings is None]): +# def prepare_pretrained_embeddings(opt, fields): +# if all([opt.both_embeddings is None, opt.src_embeddings is None, opt.tgt_embeddings is None]): # return # # assert ( -# opts.save_data +# opt.save_data # ), "-save_data is required when using \ # pretrained embeddings." # @@ -356,42 +358,42 @@ def convert_to_torch_tensor(word_to_float_list_dict, vocab): # vocs.append(vocab) # enc_vocab, dec_vocab = vocs # -# skip_lines = 1 if opts.embeddings_type == "word2vec" else 0 -# if opts.both_embeddings is not None: +# skip_lines = 1 if opt.embeddings_type == "word2vec" else 0 +# if opt.both_embeddings is not None: # set_of_src_and_tgt_vocab = set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys()) -# logger.info("Reading encoder and decoder embeddings from {}".format(opts.both_embeddings)) -# src_vectors, total_vec_count = read_embeddings(opts.both_embeddings, skip_lines, set_of_src_and_tgt_vocab) +# logger.info("Reading encoder and decoder embeddings from {}".format(opt.both_embeddings)) +# src_vectors, total_vec_count = read_embeddings(opt.both_embeddings, skip_lines, set_of_src_and_tgt_vocab) # tgt_vectors = src_vectors # logger.info("\tFound {} total vectors in file".format(total_vec_count)) # else: -# if opts.src_embeddings is not None: -# logger.info("Reading encoder embeddings from {}".format(opts.src_embeddings)) -# src_vectors, total_vec_count = read_embeddings(opts.src_embeddings, skip_lines, filter_set=enc_vocab.stoi) +# if opt.src_embeddings is not None: +# logger.info("Reading encoder embeddings from {}".format(opt.src_embeddings)) +# src_vectors, total_vec_count = read_embeddings(opt.src_embeddings, skip_lines, filter_set=enc_vocab.stoi) # logger.info("\tFound {} total vectors in file.".format(total_vec_count)) # else: # src_vectors = None -# if opts.tgt_embeddings is not None: -# logger.info("Reading decoder embeddings from {}".format(opts.tgt_embeddings)) -# tgt_vectors, total_vec_count = read_embeddings(opts.tgt_embeddings, skip_lines, filter_set=dec_vocab.stoi) +# if opt.tgt_embeddings is not None: +# logger.info("Reading decoder embeddings from {}".format(opt.tgt_embeddings)) +# tgt_vectors, total_vec_count = read_embeddings(opt.tgt_embeddings, skip_lines, filter_set=dec_vocab.stoi) # logger.info("\tFound {} total vectors in file".format(total_vec_count)) # else: # tgt_vectors = None # logger.info("After filtering to vectors in vocab:") -# if opts.src_embeddings is not None or opts.both_embeddings is not None: +# if opt.src_embeddings is not None or opt.both_embeddings is not None: # logger.info("\t* enc: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(enc_vocab, src_vectors)) -# if opts.tgt_embeddings is not None or opts.both_embeddings is not None: +# if opt.tgt_embeddings is not None or opt.both_embeddings is not None: # logger.info("\t* dec: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(dec_vocab, tgt_vectors)) # # # Write to file -# enc_output_file = opts.save_data + ".enc_embeddings.pt" -# dec_output_file = opts.save_data + ".dec_embeddings.pt" -# if opts.src_embeddings is not None or opts.both_embeddings is not None: +# enc_output_file = opt.save_data + ".enc_embeddings.pt" +# dec_output_file = opt.save_data + ".dec_embeddings.pt" +# if opt.src_embeddings is not None or opt.both_embeddings is not None: # logger.info("\nSaving encoder embeddings as:\n\t* enc: %s" % enc_output_file) # torch.save(convert_to_torch_tensor(src_vectors, enc_vocab), enc_output_file) -# # set the opts in place -# opts.pre_word_vecs_enc = enc_output_file -# if opts.tgt_embeddings is not None or opts.both_embeddings is not None: +# # set the opt in place +# opt.pre_word_vecs_enc = enc_output_file +# if opt.tgt_embeddings is not None or opt.both_embeddings is not None: # logger.info("\nSaving decoder embeddings as:\n\t* dec: %s" % dec_output_file) # torch.save(convert_to_torch_tensor(tgt_vectors, dec_vocab), dec_output_file) -# # set the opts in place -# opts.pre_word_vecs_dec = dec_output_file +# # set the opt in place +# opt.pre_word_vecs_dec = dec_output_file diff --git a/onmt/modules/gate.py b/onmt/modules/gate.py new file mode 100644 index 00000000..86babaf5 --- /dev/null +++ b/onmt/modules/gate.py @@ -0,0 +1,76 @@ +""" ContextGate module """ +import torch +import torch.nn as nn + + +def context_gate_factory(gate_type, embeddings_size, decoder_size, attention_size, output_size): + """Returns the correct ContextGate class""" + + gate_types = {'source': SourceContextGate, 'target': TargetContextGate, 'both': BothContextGate} + + assert gate_type in gate_types, "Not valid ContextGate type: {0}".format(gate_type) + return gate_types[gate_type](embeddings_size, decoder_size, attention_size, output_size) + + +class ContextGate(nn.Module): + """ + Context gate is a decoder module that takes as input the previous word + embedding, the current decoder state and the attention state, and + produces a gate. + The gate can be used to select the input from the target side context + (decoder state), from the source context (attention state) or both. + """ + + def __init__(self, embeddings_size, decoder_size, attention_size, output_size): + super(ContextGate, self).__init__() + input_size = embeddings_size + decoder_size + attention_size + self.gate = nn.Linear(input_size, output_size, bias=True) + self.sig = nn.Sigmoid() + self.source_proj = nn.Linear(attention_size, output_size) + self.target_proj = nn.Linear(embeddings_size + decoder_size, output_size) + + def forward(self, prev_emb, dec_state, attn_state): + input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1) + z = self.sig(self.gate(input_tensor)) + proj_source = self.source_proj(attn_state) + proj_target = self.target_proj(torch.cat((prev_emb, dec_state), dim=1)) + return z, proj_source, proj_target + + +class SourceContextGate(nn.Module): + """Apply the context gate only to the source context""" + + def __init__(self, embeddings_size, decoder_size, attention_size, output_size): + super(SourceContextGate, self).__init__() + self.context_gate = ContextGate(embeddings_size, decoder_size, attention_size, output_size) + self.tanh = nn.Tanh() + + def forward(self, prev_emb, dec_state, attn_state): + z, source, target = self.context_gate(prev_emb, dec_state, attn_state) + return self.tanh(target + z * source) + + +class TargetContextGate(nn.Module): + """Apply the context gate only to the target context""" + + def __init__(self, embeddings_size, decoder_size, attention_size, output_size): + super(TargetContextGate, self).__init__() + self.context_gate = ContextGate(embeddings_size, decoder_size, attention_size, output_size) + self.tanh = nn.Tanh() + + def forward(self, prev_emb, dec_state, attn_state): + z, source, target = self.context_gate(prev_emb, dec_state, attn_state) + return self.tanh(z * target + source) + + +class BothContextGate(nn.Module): + """Apply the context gate to both contexts""" + + def __init__(self, embeddings_size, decoder_size, attention_size, output_size): + super(BothContextGate, self).__init__() + self.context_gate = ContextGate(embeddings_size, decoder_size, attention_size, output_size) + self.tanh = nn.Tanh() + + def forward(self, prev_emb, dec_state, attn_state): + z, source, target = self.context_gate(prev_emb, dec_state, attn_state) + return self.tanh((1.0 - z) * target + z * source) diff --git a/onmt/modules/global_attention.py b/onmt/modules/global_attention.py new file mode 100644 index 00000000..da032c94 --- /dev/null +++ b/onmt/modules/global_attention.py @@ -0,0 +1,225 @@ +"""Global attention modules (Luong / Bahdanau)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from onmt.modules.sparse_activations import sparsemax +from onmt.utils.misc import aeq, sequence_mask + +# This class is mainly used by decoder.py for RNNs but also +# by the CNN / transformer decoder when copy attention is used +# CNN has its own attention mechanism ConvMultiStepAttention +# Transformer has its own MultiHeadedAttention + + +class GlobalAttention(nn.Module): + r""" + Global attention takes a matrix and a query vector. It + then computes a parameterized convex combination of the matrix + based on the input query. + + Constructs a unit mapping a query `q` of size `dim` + and a source matrix `H` of size `n x dim`, to an output + of size `dim`. + + + .. mermaid:: + + graph BT + A[Query] + subgraph RNN + C[H 1] + D[H 2] + E[H N] + end + F[Attn] + G[Output] + A --> F + C --> F + D --> F + E --> F + C -.-> G + D -.-> G + E -.-> G + F --> G + + All models compute the output as + :math:`c = \sum_{j=1}^{\text{SeqLength}} a_j H_j` where + :math:`a_j` is the softmax of a score function. + Then then apply a projection layer to [q, c]. + + However they + differ on how they compute the attention score. + + * Luong Attention (dot, general): + * dot: :math:`\text{score}(H_j,q) = H_j^T q` + * general: :math:`\text{score}(H_j, q) = H_j^T W_a q` + + + * Bahdanau Attention (mlp): + * :math:`\text{score}(H_j, q) = v_a^T \text{tanh}(W_a q + U_a h_j)` + + + Args: + dim (int): dimensionality of query and key + coverage (bool): use coverage term + attn_type (str): type of attention to use, options [dot,general,mlp] + attn_func (str): attention function to use, options [softmax,sparsemax] + + """ + + def __init__(self, dim, coverage=False, attn_type="dot", attn_func="softmax"): + super(GlobalAttention, self).__init__() + + self.dim = dim + assert attn_type in ["dot", "general", "mlp"], "Please select a valid attention type (got {:s}).".format( + attn_type + ) + self.attn_type = attn_type + assert attn_func in ["softmax", "sparsemax"], "Please select a valid attention function." + self.attn_func = attn_func + + if self.attn_type == "general": + self.linear_in = nn.Linear(dim, dim, bias=False) + elif self.attn_type == "mlp": + self.linear_context = nn.Linear(dim, dim, bias=False) + self.linear_query = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, 1, bias=False) + # mlp wants it with bias + out_bias = self.attn_type == "mlp" + self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias) + + if coverage: + self.linear_cover = nn.Linear(1, dim, bias=False) + + def score(self, h_t, h_s): + """ + Args: + h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)`` + h_s (FloatTensor): sequence of sources ``(batch, src_len, dim`` + + Returns: + FloatTensor: raw attention scores (unnormalized) for each src index + ``(batch, tgt_len, src_len)`` + """ + + # Check input sizes + src_batch, src_len, src_dim = h_s.size() + tgt_batch, tgt_len, tgt_dim = h_t.size() + aeq(src_batch, tgt_batch) + aeq(src_dim, tgt_dim) + aeq(self.dim, src_dim) + + if self.attn_type in ["general", "dot"]: + if self.attn_type == "general": + h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim) + h_t_ = self.linear_in(h_t_) + h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim) + h_s_ = h_s.transpose(1, 2) + # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) + return torch.bmm(h_t, h_s_) + else: + dim = self.dim + wq = self.linear_query(h_t.view(-1, dim)) + wq = wq.view(tgt_batch, tgt_len, 1, dim) + wq = wq.expand(tgt_batch, tgt_len, src_len, dim) + + uh = self.linear_context(h_s.contiguous().view(-1, dim)) + uh = uh.view(src_batch, 1, src_len, dim) + uh = uh.expand(src_batch, tgt_len, src_len, dim) + + # (batch, t_len, s_len, d) + wquh = torch.tanh(wq + uh) + + return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len) + + def forward(self, source, memory_bank, memory_lengths=None, coverage=None): + """ + + Args: + source (FloatTensor): query vectors ``(batch, tgt_len, dim)`` + memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)`` + memory_lengths (LongTensor): the source context lengths ``(batch,)`` + coverage (FloatTensor): None (not supported yet) + + Returns: + (FloatTensor, FloatTensor): + + * Computed vector ``(tgt_len, batch, dim)`` + * Attention distribtutions for each query + ``(tgt_len, batch, src_len)`` + """ + + # one step input + if source.dim() == 2: + one_step = True + source = source.unsqueeze(1) + else: + one_step = False + + batch, source_l, dim = memory_bank.size() + batch_, target_l, dim_ = source.size() + aeq(batch, batch_) + aeq(dim, dim_) + aeq(self.dim, dim) + if coverage is not None: + batch_, source_l_ = coverage.size() + aeq(batch, batch_) + aeq(source_l, source_l_) + + if coverage is not None: + cover = coverage.view(-1).unsqueeze(1) + memory_bank += self.linear_cover(cover).view_as(memory_bank) + memory_bank = torch.tanh(memory_bank) + + # compute attention scores, as in Luong et al. + align = self.score(source, memory_bank) + + if memory_lengths is not None: + mask = sequence_mask(memory_lengths, max_len=align.size(-1)) + mask = mask.unsqueeze(1) # Make it broadcastable. + align.masked_fill_(~mask, -float('inf')) + + # Softmax or sparsemax to normalize attention weights + if self.attn_func == "softmax": + align_vectors = F.softmax(align.view(batch * target_l, source_l), -1) + else: + align_vectors = sparsemax(align.view(batch * target_l, source_l), -1) + align_vectors = align_vectors.view(batch, target_l, source_l) + + # each context vector c_t is the weighted average + # over all the source hidden states + c = torch.bmm(align_vectors, memory_bank) + + # concatenate + concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2) + attn_h = self.linear_out(concat_c).view(batch, target_l, dim) + if self.attn_type in ["general", "dot"]: + attn_h = torch.tanh(attn_h) + + if one_step: + attn_h = attn_h.squeeze(1) + align_vectors = align_vectors.squeeze(1) + + # Check output sizes + batch_, dim_ = attn_h.size() + aeq(batch, batch_) + aeq(dim, dim_) + batch_, source_l_ = align_vectors.size() + aeq(batch, batch_) + aeq(source_l, source_l_) + + else: + attn_h = attn_h.transpose(0, 1).contiguous() + align_vectors = align_vectors.transpose(0, 1).contiguous() + # Check output sizes + target_l_, batch_, dim_ = attn_h.size() + aeq(target_l, target_l_) + aeq(batch, batch_) + aeq(dim, dim_) + target_l_, batch_, source_l_ = align_vectors.size() + aeq(target_l, target_l_) + aeq(batch, batch_) + aeq(source_l, source_l_) + + return attn_h, align_vectors diff --git a/mammoth/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py similarity index 98% rename from mammoth/modules/multi_headed_attn.py rename to onmt/modules/multi_headed_attn.py index 1a4b3028..d44912b1 100644 --- a/mammoth/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -3,9 +3,9 @@ import torch import torch.nn as nn -from mammoth.utils.misc import generate_relative_positions_matrix, relative_matmul +from onmt.utils.misc import generate_relative_positions_matrix, relative_matmul -# from mammoth.utils.misc import aeq +# from onmt.utils.misc import aeq class MultiHeadedAttention(nn.Module): diff --git a/mammoth/modules/position_ffn.py b/onmt/modules/position_ffn.py similarity index 100% rename from mammoth/modules/position_ffn.py rename to onmt/modules/position_ffn.py diff --git a/onmt/modules/sparse_activations.py b/onmt/modules/sparse_activations.py new file mode 100644 index 00000000..7a5e8a75 --- /dev/null +++ b/onmt/modules/sparse_activations.py @@ -0,0 +1,97 @@ +""" +An implementation of sparsemax (Martins & Astudillo, 2016). See +:cite:`DBLP:journals/corr/MartinsA16` for detailed description. + +By Ben Peters and Vlad Niculae +""" + +import torch +from torch.autograd import Function +from torch.cuda.amp import custom_fwd, custom_bwd +import torch.nn as nn + + +def _make_ix_like(input, dim=0): + d = input.size(dim) + rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) + view = [1] * input.dim() + view[0] = -1 + return rho.view(view).transpose(0, dim) + + +def _threshold_and_support(input, dim=0): + """Sparsemax building block: compute the threshold + + Args: + input: any dimension + dim: dimension along which to apply the sparsemax + + Returns: + the threshold value + """ + + input_srt, _ = torch.sort(input, descending=True, dim=dim) + input_cumsum = input_srt.cumsum(dim) - 1 + rhos = _make_ix_like(input, dim) + support = rhos * input_srt > input_cumsum + + support_size = support.sum(dim=dim).unsqueeze(dim) + tau = input_cumsum.gather(dim, support_size - 1) + tau /= support_size.to(input.dtype) + return tau, support_size + + +class SparsemaxFunction(Function): + @staticmethod + @custom_fwd + def forward(ctx, input, dim=0): + """sparsemax: normalizing sparse transform (a la softmax) + + Parameters: + input (Tensor): any shape + dim: dimension along which to apply sparsemax + + Returns: + output (Tensor): same shape as input + """ + ctx.dim = dim + max_val, _ = input.max(dim=dim, keepdim=True) + input -= max_val # same numerical stability trick as for softmax + tau, supp_size = _threshold_and_support(input, dim=dim) + output = torch.clamp(input - tau, min=0) + ctx.save_for_backward(supp_size, output) + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + supp_size, output = ctx.saved_tensors + dim = ctx.dim + grad_input = grad_output.clone() + grad_input[output == 0] = 0 + + v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() + v_hat = v_hat.unsqueeze(dim) + grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) + return grad_input, None + + +sparsemax = SparsemaxFunction.apply + + +class Sparsemax(nn.Module): + def __init__(self, dim=0): + self.dim = dim + super(Sparsemax, self).__init__() + + def forward(self, input): + return sparsemax(input, self.dim) + + +class LogSparsemax(nn.Module): + def __init__(self, dim=0): + self.dim = dim + super(LogSparsemax, self).__init__() + + def forward(self, input): + return torch.log(sparsemax(input, self.dim)) diff --git a/onmt/modules/sparse_losses.py b/onmt/modules/sparse_losses.py new file mode 100644 index 00000000..3e67c885 --- /dev/null +++ b/onmt/modules/sparse_losses.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_fwd, custom_bwd +from onmt.modules.sparse_activations import _threshold_and_support +from onmt.utils.misc import aeq + + +class SparsemaxLossFunction(Function): + @staticmethod + @custom_fwd + def forward(ctx, input, target): + """ + input (FloatTensor): ``(n, num_classes)``. + target (LongTensor): ``(n,)``, the indices of the target classes + """ + input_batch, classes = input.size() + target_batch = target.size(0) + aeq(input_batch, target_batch) + + z_k = input.gather(1, target.unsqueeze(1)).squeeze() + tau_z, support_size = _threshold_and_support(input, dim=1) + support = input > tau_z + x = torch.where(support, input**2 - tau_z**2, torch.tensor(0.0, device=input.device)).sum(dim=1) + ctx.save_for_backward(input, target, tau_z) + # clamping necessary because of numerical errors: loss should be lower + # bounded by zero, but negative values near zero are possible without + # the clamp + return torch.clamp(x / 2 - z_k + 0.5, min=0.0) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + input, target, tau_z = ctx.saved_tensors + sparsemax_out = torch.clamp(input - tau_z, min=0) + delta = torch.zeros_like(sparsemax_out) + delta.scatter_(1, target.unsqueeze(1), 1) + return sparsemax_out - delta, None + + +sparsemax_loss = SparsemaxLossFunction.apply + + +class SparsemaxLoss(nn.Module): + """ + An implementation of sparsemax loss, first proposed in + :cite:`DBLP:journals/corr/MartinsA16`. If using + a sparse output layer, it is not possible to use negative log likelihood + because the loss is infinite in the case the target is assigned zero + probability. Inputs to SparsemaxLoss are arbitrary dense real-valued + vectors (like in nn.CrossEntropyLoss), not probability vectors (like in + nn.NLLLoss). + """ + + def __init__(self, weight=None, ignore_index=-100, reduction='elementwise_mean'): + assert reduction in ['elementwise_mean', 'sum', 'none'] + self.reduction = reduction + self.weight = weight + self.ignore_index = ignore_index + super(SparsemaxLoss, self).__init__() + + def forward(self, input, target): + loss = sparsemax_loss(input, target) + if self.ignore_index >= 0: + ignored_positions = target == self.ignore_index + size = float((target.size(0) - ignored_positions.sum()).item()) + loss.masked_fill_(ignored_positions, 0.0) + else: + size = float(target.size(0)) + if self.reduction == 'sum': + loss = loss.sum() + elif self.reduction == 'elementwise_mean': + loss = loss.sum() / size + return loss diff --git a/onmt/modules/stable_embeddings.py b/onmt/modules/stable_embeddings.py new file mode 100644 index 00000000..8111cc76 --- /dev/null +++ b/onmt/modules/stable_embeddings.py @@ -0,0 +1,95 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from typing import Optional + +from torch import Tensor +import torch.nn.functional as F + +# from bitsandbytes.optim import GlobalOptimManager + + +class StableEmbedding(torch.nn.Embedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + ) -> None: + super(StableEmbedding, self).__init__( + num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight + ) + self.norm = torch.nn.LayerNorm(embedding_dim, eps=1e-6) + # GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + + def reset_parameters(self) -> None: + torch.nn.init.xavier_uniform_(self.weight) + self._fill_padding_idx_with_zero() + + ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + to make the Layer compatible with Pytorch < 1.9. + This means that if this changes in future PyTorch releases this need to change too + which is cumbersome. However, with this we can ensure compatibility with previous + PyTorch releases. + ''' + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + emb = F.embedding( + input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse + ) + + return self.norm(emb) + + +class Embedding(torch.nn.Embedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + ) -> None: + super(Embedding, self).__init__( + num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight + ) + # GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + + def reset_parameters(self) -> None: + torch.nn.init.xavier_uniform_(self.weight) + self._fill_padding_idx_with_zero() + + ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + to make the Layer compatible with Pytorch < 1.9. + This means that if this changes in future PyTorch releases this need to change too + which is cumbersome. However, with this we can ensure compatibility with previous + PyTorch releases. + ''' + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + emb = F.embedding( + input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse + ) + + return emb diff --git a/onmt/modules/structured_attention.py b/onmt/modules/structured_attention.py new file mode 100644 index 00000000..59206b61 --- /dev/null +++ b/onmt/modules/structured_attention.py @@ -0,0 +1,35 @@ +import torch.nn as nn +import torch +import torch.cuda + + +class MatrixTree(nn.Module): + """Implementation of the matrix-tree theorem for computing marginals + of non-projective dependency parsing. This attention layer is used + in the paper "Learning Structured Text Representations" + :cite:`DBLP:journals/corr/LiuL17d`. + """ + + def __init__(self, eps=1e-5): + self.eps = eps + super(MatrixTree, self).__init__() + + def forward(self, input): + laplacian = input.exp() + self.eps + output = input.clone() + for b in range(input.size(0)): + lap = laplacian[b].masked_fill(torch.eye(input.size(1), device=input.device).ne(0), 0) + lap = -lap + torch.diag(lap.sum(0)) + # store roots on diagonal + lap[0] = input[b].diag().exp() + inv_laplacian = lap.inverse() + + factor = inv_laplacian.diag().unsqueeze(1).expand_as(input[b]).transpose(0, 1) + term1 = input[b].exp().mul(factor).clone() + term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone() + term1[:, 0] = 0 + term2[0] = 0 + output[b] = term1 - term2 + roots_output = input[b].diag().exp().mul(inv_laplacian.transpose(0, 1)[0]) + output[b] = output[b] + torch.diag(roots_output) + return output diff --git a/mammoth/modules/util_class.py b/onmt/modules/util_class.py similarity index 100% rename from mammoth/modules/util_class.py rename to onmt/modules/util_class.py diff --git a/onmt/modules/weight_norm.py b/onmt/modules/weight_norm.py new file mode 100644 index 00000000..723a7d74 --- /dev/null +++ b/onmt/modules/weight_norm.py @@ -0,0 +1,224 @@ +""" Weights normalization modules """ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter + + +def get_var_maybe_avg(namespace, var_name, training, polyak_decay): + """utility for retrieving polyak averaged params + Update average + """ + v = getattr(namespace, var_name) + v_avg = getattr(namespace, var_name + '_avg') + v_avg -= (1 - polyak_decay) * (v_avg - v.data) + + if training: + return v + else: + return v_avg + + +def get_vars_maybe_avg(namespace, var_names, training, polyak_decay): + """utility for retrieving polyak averaged params""" + vars = [] + for vn in var_names: + vars.append(get_var_maybe_avg(namespace, vn, training, polyak_decay)) + return vars + + +class WeightNormLinear(nn.Linear): + """ + Implementation of "Weight Normalization: A Simple Reparameterization + to Accelerate Training of Deep Neural Networks" + :cite:`DBLP:journals/corr/SalimansK16` + + As a reparameterization method, weight normalization is same + as BatchNormalization, but it doesn't depend on minibatch. + + NOTE: This is used nowhere in the code at this stage + Vincent Nguyen 05/18/2018 + """ + + def __init__(self, in_features, out_features, init_scale=1.0, polyak_decay=0.9995): + super(WeightNormLinear, self).__init__(in_features, out_features, bias=True) + + self.V = self.weight + self.g = Parameter(torch.Tensor(out_features)) + self.b = self.bias + + self.register_buffer('V_avg', torch.zeros(out_features, in_features)) + self.register_buffer('g_avg', torch.zeros(out_features)) + self.register_buffer('b_avg', torch.zeros(out_features)) + + self.init_scale = init_scale + self.polyak_decay = polyak_decay + self.reset_parameters() + + def reset_parameters(self): + return + + def forward(self, x, init=False): + if init is True: + # out_features * in_features + self.V.data.copy_(torch.randn(self.V.data.size()).type_as(self.V.data) * 0.05) + # norm is out_features * 1 + v_norm = self.V.data / self.V.data.norm(2, 1).expand_as(self.V.data) + # batch_size * out_features + x_init = F.linear(x, v_norm).data + # out_features + m_init, v_init = x_init.mean(0).squeeze(0), x_init.var(0).squeeze(0) + # out_features + scale_init = self.init_scale / torch.sqrt(v_init + 1e-10) + self.g.data.copy_(scale_init) + self.b.data.copy_(-m_init * scale_init) + x_init = scale_init.view(1, -1).expand_as(x_init) * (x_init - m_init.view(1, -1).expand_as(x_init)) + self.V_avg.copy_(self.V.data) + self.g_avg.copy_(self.g.data) + self.b_avg.copy_(self.b.data) + return x_init + else: + v, g, b = get_vars_maybe_avg(self, ['V', 'g', 'b'], self.training, polyak_decay=self.polyak_decay) + # batch_size * out_features + x = F.linear(x, v) + scalar = g / torch.norm(v, 2, 1).squeeze(1) + x = scalar.view(1, -1).expand_as(x) * x + b.view(1, -1).expand_as(x) + return x + + +class WeightNormConv2d(nn.Conv2d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + init_scale=1.0, + polyak_decay=0.9995, + ): + super(WeightNormConv2d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, groups + ) + + self.V = self.weight + self.g = Parameter(torch.Tensor(out_channels)) + self.b = self.bias + + self.register_buffer('V_avg', torch.zeros(self.V.size())) + self.register_buffer('g_avg', torch.zeros(out_channels)) + self.register_buffer('b_avg', torch.zeros(out_channels)) + + self.init_scale = init_scale + self.polyak_decay = polyak_decay + self.reset_parameters() + + def reset_parameters(self): + return + + def forward(self, x, init=False): + if init is True: + # out_channels, in_channels // groups, * kernel_size + self.V.data.copy_(torch.randn(self.V.data.size()).type_as(self.V.data) * 0.05) + v_norm = self.V.data / self.V.data.view(self.out_channels, -1).norm(2, 1).view( + self.out_channels, *([1] * (len(self.kernel_size) + 1)) + ).expand_as(self.V.data) + x_init = F.conv2d(x, v_norm, None, self.stride, self.padding, self.dilation, self.groups).data + t_x_init = x_init.transpose(0, 1).contiguous().view(self.out_channels, -1) + m_init, v_init = t_x_init.mean(1).squeeze(1), t_x_init.var(1).squeeze(1) + # out_features + scale_init = self.init_scale / torch.sqrt(v_init + 1e-10) + self.g.data.copy_(scale_init) + self.b.data.copy_(-m_init * scale_init) + scale_init_shape = scale_init.view(1, self.out_channels, *([1] * (len(x_init.size()) - 2))) + m_init_shape = m_init.view(1, self.out_channels, *([1] * (len(x_init.size()) - 2))) + x_init = scale_init_shape.expand_as(x_init) * (x_init - m_init_shape.expand_as(x_init)) + self.V_avg.copy_(self.V.data) + self.g_avg.copy_(self.g.data) + self.b_avg.copy_(self.b.data) + return x_init + else: + v, g, b = get_vars_maybe_avg(self, ['V', 'g', 'b'], self.training, polyak_decay=self.polyak_decay) + + scalar = torch.norm(v.view(self.out_channels, -1), 2, 1) + if len(scalar.size()) == 2: + scalar = g / scalar.squeeze(1) + else: + scalar = g / scalar + + w = scalar.view(self.out_channels, *([1] * (len(v.size()) - 1))).expand_as(v) * v + + x = F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups) + return x + + +# This is used nowhere in the code at the moment (Vincent Nguyen 05/18/2018) + + +class WeightNormConvTranspose2d(nn.ConvTranspose2d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + init_scale=1.0, + polyak_decay=0.9995, + ): + super(WeightNormConvTranspose2d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, output_padding, groups + ) + # in_channels, out_channels, *kernel_size + self.V = self.weight + self.g = Parameter(torch.Tensor(out_channels)) + self.b = self.bias + + self.register_buffer('V_avg', torch.zeros(self.V.size())) + self.register_buffer('g_avg', torch.zeros(out_channels)) + self.register_buffer('b_avg', torch.zeros(out_channels)) + + self.init_scale = init_scale + self.polyak_decay = polyak_decay + self.reset_parameters() + + def reset_parameters(self): + return + + def forward(self, x, init=False): + if init is True: + # in_channels, out_channels, *kernel_size + self.V.data.copy_(torch.randn(self.V.data.size()).type_as(self.V.data) * 0.05) + v_norm = self.V.data / self.V.data.transpose(0, 1).contiguous().view(self.out_channels, -1).norm(2, 1).view( + self.in_channels, self.out_channels, *([1] * len(self.kernel_size)) + ).expand_as(self.V.data) + x_init = F.conv_transpose2d( + x, v_norm, None, self.stride, self.padding, self.output_padding, self.groups + ).data + # self.out_channels, 1 + t_x_init = x_init.tranpose(0, 1).contiguous().view(self.out_channels, -1) + # out_features + m_init, v_init = t_x_init.mean(1).squeeze(1), t_x_init.var(1).squeeze(1) + # out_features + scale_init = self.init_scale / torch.sqrt(v_init + 1e-10) + self.g.data.copy_(scale_init) + self.b.data.copy_(-m_init * scale_init) + scale_init_shape = scale_init.view(1, self.out_channels, *([1] * (len(x_init.size()) - 2))) + m_init_shape = m_init.view(1, self.out_channels, *([1] * (len(x_init.size()) - 2))) + + x_init = scale_init_shape.expand_as(x_init) * (x_init - m_init_shape.expand_as(x_init)) + self.V_avg.copy_(self.V.data) + self.g_avg.copy_(self.g.data) + self.b_avg.copy_(self.b.data) + return x_init + else: + v, g, b = get_vars_maybe_avg(self, ['V', 'g', 'b'], self.training, polyak_decay=self.polyak_decay) + scalar = g / torch.norm(v.transpose(0, 1).contiguous().view(self.out_channels, -1), 2, 1).squeeze(1) + w = scalar.view(self.in_channels, self.out_channels, *([1] * (len(v.size()) - 2))).expand_as(v) * v + + x = F.conv_transpose2d(x, w, b, self.stride, self.padding, self.output_padding, self.groups) + return x diff --git a/mammoth/opts.py b/onmt/opts.py similarity index 93% rename from mammoth/opts.py rename to onmt/opts.py index 22a3dfd6..695f8977 100644 --- a/mammoth/opts.py +++ b/onmt/opts.py @@ -1,11 +1,12 @@ """ Implementation of all available options """ import configargparse -from mammoth.constants import ModelTask -from mammoth.modules.position_ffn import ACTIVATION_FUNCTIONS -from mammoth.modules.position_ffn import ActivationFunction -from mammoth.transforms import AVAILABLE_TRANSFORMS -from mammoth.distributed import TASK_DISTRIBUTION_STRATEGIES +from onmt.constants import ModelTask +from onmt.models.sru import CheckSRU +from onmt.modules.position_ffn import ACTIVATION_FUNCTIONS +from onmt.modules.position_ffn import ActivationFunction +from onmt.transforms import AVAILABLE_TRANSFORMS +from onmt.utils.distributed import TASK_DISTRIBUTION_STRATEGIES def config_opts(parser): @@ -55,7 +56,7 @@ def _add_logging_opts(parser, is_train=True): "--tensorboard_log_dir", "-tensorboard_log_dir", type=str, - default="runs/mammoth", + default="runs/onmt", help="Log directory for Tensorboard. This is also the name of the run.", ) group.add( @@ -93,10 +94,10 @@ def _add_reproducibility_opts(parser): def _add_dynamic_corpus_opts(parser, build_vocab_only=False): """Options related to training corpus, type: a list of dictionary.""" - group = parser.add_argument_group('Data/Tasks') + group = parser.add_argument_group('Data') group.add( - "-tasks", - "--tasks", + "-data", + "--data", required=True, help="List of datasets and their specifications. See examples/*.yaml for further details.", ) @@ -273,7 +274,7 @@ def _add_dynamic_transform_opts(parser): """Options related to transforms. Options that specified in the definitions of each transform class - at `mammoth/transforms/*.py`. + at `onmt/transforms/*.py`. """ for name, transform_cls in AVAILABLE_TRANSFORMS.items(): transform_cls.add_options(parser) @@ -284,7 +285,7 @@ def dynamic_prepare_opts(parser, build_vocab_only=False): Add all dynamic data prepare related options to parser. If `build_vocab_only` set to True, then only contains options that - will be used in `mammoth/bin/build_vocab.py`. + will be used in `onmt/bin/build_vocab.py`. """ config_opts(parser) _add_dynamic_corpus_opts(parser, build_vocab_only=build_vocab_only) @@ -304,6 +305,9 @@ def model_opts(parser): # Embedding Options group = parser.add_argument_group('Model-Embeddings') + group.add('--src_word_vec_size', '-src_word_vec_size', type=int, default=500, help='Word embedding size for src.') + group.add('--tgt_word_vec_size', '-tgt_word_vec_size', type=int, default=500, help='Word embedding size for tgt.') + group.add('--word_vec_size', '-word_vec_size', type=int, default=-1, help='Word embedding size for src and tgt.') group.add( '--share_decoder_embeddings', @@ -382,41 +386,42 @@ def model_opts(parser): '--encoder_type', '-encoder_type', type=str, - default='transformer', - choices=['mean', 'transformer'], + default='rnn', + choices=['rnn', 'brnn', 'ggnn', 'mean', 'transformer', 'cnn', 'transformer_lm'], help="Type of encoder layer to use. Non-RNN layers " "are experimental. Options are " - "[mean|transformer].", + "[rnn|brnn|ggnn|mean|transformer|cnn|transformer_lm].", ) group.add( '--decoder_type', '-decoder_type', type=str, - default='transformer', - choices=['transformer'], + default='rnn', + choices=['rnn', 'transformer', 'cnn', 'transformer_lm'], help="Type of decoder layer to use. Non-RNN layers " "are experimental. Options are " - "[transformer].", + "[rnn|transformer|cnn|transformer].", ) group.add('--layers', '-layers', type=int, default=-1, help='Deprecated') group.add('--enc_layers', '-enc_layers', nargs='+', type=int, help='Number of layers in each encoder') group.add('--dec_layers', '-dec_layers', nargs='+', type=int, help='Number of layers in each decoder') group.add( - '--model_dim', - '-model_dim', + '--rnn_size', + '-rnn_size', type=int, default=-1, - help="Size of rnn hidden states.", + help="Size of rnn hidden states. Overwrites enc_rnn_size and dec_rnn_size", + ) + group.add('--enc_rnn_size', '-enc_rnn_size', type=int, default=500, help="Size of encoder rnn hidden states.") + group.add('--dec_rnn_size', '-dec_rnn_size', type=int, default=500, help="Size of decoder rnn hidden states.") + group.add( + '--cnn_kernel_width', + '-cnn_kernel_width', + type=int, + default=3, + help="Size of windows in the cnn, the kernel_size is (cnn_kernel_width, 1) in conv layer", ) - - # group.add( - # '--cnn_kernel_width', - # '-cnn_kernel_width', - # type=int, - # default=3, - # help="Size of windows in the cnn, the kernel_size is (cnn_kernel_width, 1) in conv layer", - # ) group.add( '--pos_ffn_activation_fn', @@ -430,32 +435,43 @@ def model_opts(parser): f' {ActivationFunction.relu}.', ) - # group.add( - # '--input_feed', - # '-input_feed', - # type=int, - # default=1, - # help="Feed the context vector at each time step as " - # "additional input (via concatenation with the word " - # "embeddings) to the decoder.", - # ) + group.add( + '--input_feed', + '-input_feed', + type=int, + default=1, + help="Feed the context vector at each time step as " + "additional input (via concatenation with the word " + "embeddings) to the decoder.", + ) group.add( '--bridge', '-bridge', action="store_true", help="Have an additional layer between the last encoder state and the first decoder state", ) + group.add( + '--rnn_type', + '-rnn_type', + type=str, + default='LSTM', + choices=['LSTM', 'GRU', 'SRU'], + action=CheckSRU, + help="The gate type to use in the RNNs", + ) # group.add('--residual', '-residual', action="store_true", # help="Add residual connections between RNN layers.") - # group.add( - # '--context_gate', - # '-context_gate', - # type=str, - # default=None, - # choices=['source', 'target', 'both'], - # help="Type of context gate to use. Do not select for no context gate.", - # ) + group.add('--brnn', '-brnn', action=DeprecateAction, help="Deprecated, use `encoder_type`.") + + group.add( + '--context_gate', + '-context_gate', + type=str, + default=None, + choices=['source', 'target', 'both'], + help="Type of context gate to use. Do not select for no context gate.", + ) # The following options (bridge_extra_node to n_steps) are used # for training with --encoder_type ggnn (Gated Graph Neural Network). @@ -498,7 +514,7 @@ def model_opts(parser): '-global_attention_function', type=str, default="softmax", - choices=["softmax"], + choices=["softmax", "sparsemax"], ) group.add( '--self_attn_type', @@ -566,10 +582,10 @@ def model_opts(parser): '--generator_function', '-generator_function', default="softmax", - choices=["softmax"], + choices=["softmax", "sparsemax"], help="Which function to use for generating " "probabilities over the target vocabulary (choices: " - "softmax)", + "softmax, sparsemax)", ) group.add('--copy_attn_force', '-copy_attn_force', action="store_true", help='When available, train to copy.') group.add('--reuse_copy_attn', '-reuse_copy_attn', action="store_true", help="Reuse standard attention for copy") @@ -600,7 +616,7 @@ def model_opts(parser): type=str, default="O1", choices=["O0", "O1", "O2", "O3"], - help="For FP16 training, the opt_level to use. See https://nvidia.github.io/apex/amp.html#opts-levels.", + help="For FP16 training, the opt_level to use. See https://nvidia.github.io/apex/amp.html#opt-levels.", ) # attention bridge options @@ -841,7 +857,7 @@ def _add_train_general_opts(parser): '--optim', '-optim', default='sgd', - choices=['sgd', 'adagrad', 'adadelta', 'adam', 'adamw', 'adafactor', 'fusedadam'], + choices=['sgd', 'adagrad', 'adadelta', 'adam', 'adamw', 'sparseadam', 'adafactor', 'fusedadam'], help="Optimization method.", ) group.add( @@ -993,8 +1009,8 @@ def _add_train_general_opts(parser): def _add_train_dynamic_data(parser): group = parser.add_argument_group("Dynamic data") group.add( - "-pool_size", - "--pool_size", + "-bucket_size", + "--bucket_size", type=int, default=2048, help="Number of examples to dynamically pool before batching.", @@ -1204,7 +1220,7 @@ def translate_opts(parser, dynamic=False): help="Divide src and tgt (if applicable) into " "smaller multiple src and tgt files, then " "build shards, each shard will have " - "opts.shard_size samples except last shard. " + "opt.shard_size samples except last shard. " "shard_size=0 means no segmentation " "shard_size>0 means segment dataset into multiple shards, " "each shard has shard_size samples", diff --git a/mammoth/rmsnorm_torch.py b/onmt/rmsnorm_torch.py similarity index 100% rename from mammoth/rmsnorm_torch.py rename to onmt/rmsnorm_torch.py diff --git a/mammoth/tests/__init__.py b/onmt/tests/__init__.py similarity index 100% rename from mammoth/tests/__init__.py rename to onmt/tests/__init__.py diff --git a/mammoth/tests/output_hyp.txt b/onmt/tests/output_hyp.txt similarity index 100% rename from mammoth/tests/output_hyp.txt rename to onmt/tests/output_hyp.txt diff --git a/mammoth/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh similarity index 100% rename from mammoth/tests/pull_request_chk.sh rename to onmt/tests/pull_request_chk.sh diff --git a/mammoth/tests/rebuild_test_models.sh b/onmt/tests/rebuild_test_models.sh similarity index 100% rename from mammoth/tests/rebuild_test_models.sh rename to onmt/tests/rebuild_test_models.sh diff --git a/mammoth/tests/sample_glove.txt b/onmt/tests/sample_glove.txt similarity index 100% rename from mammoth/tests/sample_glove.txt rename to onmt/tests/sample_glove.txt diff --git a/onmt/tests/test_attention.py b/onmt/tests/test_attention.py new file mode 100644 index 00000000..acffac3e --- /dev/null +++ b/onmt/tests/test_attention.py @@ -0,0 +1,33 @@ +""" +Here come the tests for attention types and their compatibility +""" +import unittest +import torch +from torch.autograd import Variable + +import onmt + + +class TestAttention(unittest.TestCase): + def test_masked_global_attention(self): + + source_lengths = torch.IntTensor([7, 3, 5, 2]) + # illegal_weights_mask = torch.ByteTensor([ + # [0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 1, 1, 1, 1], + # [0, 0, 0, 0, 0, 1, 1], + # [0, 0, 1, 1, 1, 1, 1]]) + + batch_size = source_lengths.size(0) + dim = 20 + + memory_bank = Variable(torch.randn(batch_size, source_lengths.max(), dim)) + hidden = Variable(torch.randn(batch_size, dim)) + + attn = onmt.modules.GlobalAttention(dim) + + _, alignments = attn(hidden, memory_bank, memory_lengths=source_lengths) + # TODO: fix for pytorch 0.3 + # illegal_weights = alignments.masked_select(illegal_weights_mask) + + # self.assertEqual(0.0, illegal_weights.data.sum()) diff --git a/mammoth/tests/test_beam_search.py b/onmt/tests/test_beam_search.py similarity index 99% rename from mammoth/tests/test_beam_search.py rename to onmt/tests/test_beam_search.py index f43dd134..8a6e54a4 100644 --- a/mammoth/tests/test_beam_search.py +++ b/onmt/tests/test_beam_search.py @@ -1,6 +1,6 @@ import unittest -from mammoth.translate.beam_search import BeamSearch, GNMTGlobalScorer -from mammoth.translate.beam_search import BeamSearchLM +from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer +from onmt.translate.beam_search import BeamSearchLM from copy import deepcopy diff --git a/onmt/tests/test_copy_generator.py b/onmt/tests/test_copy_generator.py new file mode 100644 index 00000000..4b3291fa --- /dev/null +++ b/onmt/tests/test_copy_generator.py @@ -0,0 +1,113 @@ +import unittest +from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss + +import itertools +from copy import deepcopy + +import torch +from torch.nn.functional import softmax + +from onmt.tests.utils_for_tests import product_dict + + +class TestCopyGenerator(unittest.TestCase): + INIT_CASES = list( + product_dict( + input_size=[172], + output_size=[319], + pad_idx=[0, 39], + ) + ) + PARAMS = list(product_dict(batch_size=[1, 14], max_seq_len=[23], tgt_max_len=[50], n_extra_words=[107])) + + @classmethod + def dummy_inputs(cls, params, init_case): + hidden = torch.randn((params["batch_size"] * params["tgt_max_len"], init_case["input_size"])) + attn = torch.randn((params["batch_size"] * params["tgt_max_len"], params["max_seq_len"])) + src_map = torch.randn((params["max_seq_len"], params["batch_size"], params["n_extra_words"])) + return hidden, attn, src_map + + @classmethod + def expected_shape(cls, params, init_case): + return params["tgt_max_len"] * params["batch_size"], init_case["output_size"] + params["n_extra_words"] + + def test_copy_gen_forward_shape(self): + for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): + cgen = CopyGenerator(**init_case) + dummy_in = self.dummy_inputs(params, init_case) + res = cgen(*dummy_in) + expected_shape = self.expected_shape(params, init_case) + self.assertEqual(res.shape, expected_shape, init_case.__str__()) + + def test_copy_gen_outp_has_no_prob_of_pad(self): + for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): + cgen = CopyGenerator(**init_case) + dummy_in = self.dummy_inputs(params, init_case) + res = cgen(*dummy_in) + self.assertTrue(res[:, init_case["pad_idx"]].allclose(torch.tensor(0.0))) + + def test_copy_gen_trainable_params_update(self): + for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): + cgen = CopyGenerator(**init_case) + trainable_params = {n: p for n, p in cgen.named_parameters() if p.requires_grad} + assert len(trainable_params) > 0 # sanity check + old_weights = deepcopy(trainable_params) + dummy_in = self.dummy_inputs(params, init_case) + res = cgen(*dummy_in) + pretend_loss = res.sum() + pretend_loss.backward() + dummy_optim = torch.optim.SGD(trainable_params.values(), 1) + dummy_optim.step() + for param_name in old_weights.keys(): + self.assertTrue( + trainable_params[param_name].ne(old_weights[param_name]).any(), + param_name + " " + init_case.__str__(), + ) + + +class TestCopyGeneratorLoss(unittest.TestCase): + INIT_CASES = list( + product_dict(vocab_size=[172], unk_index=[0, 39], ignore_index=[1, 17], force_copy=[True, False]) # pad idx + ) + PARAMS = list(product_dict(batch_size=[1, 14], tgt_max_len=[50], n_extra_words=[107])) + + @classmethod + def dummy_inputs(cls, params, init_case): + n_unique_src_words = 13 + scores = torch.randn( + (params["batch_size"] * params["tgt_max_len"], init_case["vocab_size"] + n_unique_src_words) + ) + scores = softmax(scores, dim=1) + align = torch.randint(0, n_unique_src_words, (params["batch_size"] * params["tgt_max_len"],)) + target = torch.randint(0, init_case["vocab_size"], (params["batch_size"] * params["tgt_max_len"],)) + target[0] = init_case["unk_index"] + target[1] = init_case["ignore_index"] + return scores, align, target + + @classmethod + def expected_shape(cls, params, init_case): + return (params["batch_size"] * params["tgt_max_len"],) + + def test_copy_loss_forward_shape(self): + for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): + loss = CopyGeneratorLoss(**init_case) + dummy_in = self.dummy_inputs(params, init_case) + res = loss(*dummy_in) + expected_shape = self.expected_shape(params, init_case) + self.assertEqual(res.shape, expected_shape, init_case.__str__()) + + def test_copy_loss_ignore_index_is_ignored(self): + for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): + loss = CopyGeneratorLoss(**init_case) + scores, align, target = self.dummy_inputs(params, init_case) + res = loss(scores, align, target) + should_be_ignored = (target == init_case["ignore_index"]).nonzero(as_tuple=False) + assert len(should_be_ignored) > 0 # otherwise not testing anything + self.assertTrue(res[should_be_ignored].allclose(torch.tensor(0.0))) + + def test_copy_loss_output_range_is_positive(self): + for params, init_case in itertools.product(self.PARAMS, self.INIT_CASES): + loss = CopyGeneratorLoss(**init_case) + dummy_in = self.dummy_inputs(params, init_case) + res = loss(*dummy_in) + self.assertTrue((res >= 0).all()) diff --git a/mammoth/tests/test_data_prepare.py b/onmt/tests/test_data_prepare.py similarity index 79% rename from mammoth/tests/test_data_prepare.py rename to onmt/tests/test_data_prepare.py index d1780982..9c7fdffb 100644 --- a/mammoth/tests/test_data_prepare.py +++ b/onmt/tests/test_data_prepare.py @@ -8,10 +8,10 @@ # import glob # import os # -# from mammoth.utils.parse import ArgumentParser -# from mammoth.opts import dynamic_prepare_opts -# from mammoth.bin.train import prepare_fields_transforms -# from mammoth.constants import CorpusName +# from onmt.utils.parse import ArgumentParser +# from onmt.opts import dynamic_prepare_opts +# from onmt.bin.train import prepare_fields_transforms +# from onmt.constants import CorpusName # # # SAVE_DATA_PREFIX = 'data/test_data_prepare' @@ -27,11 +27,11 @@ # '-tgt_vocab', 'data/vocab-train.tgt' # ] # -# opts = parser.parse_known_args(default_opts)[0] +# opt = parser.parse_known_args(default_opts)[0] # # Inject some dummy training options that may needed when build fields -# opts.copy_attn = False -# ArgumentParser.validate_prepare_opts(opts) -# return opts +# opt.copy_attn = False +# ArgumentParser.validate_prepare_opts(opt) +# return opt # # # default_opts = get_default_opts() @@ -40,15 +40,15 @@ # class TestData(unittest.TestCase): # def __init__(self, *args, **kwargs): # super(TestData, self).__init__(*args, **kwargs) -# self.opts = default_opts +# self.opt = default_opts # -# def dataset_build(self, opts): +# def dataset_build(self, opt): # try: -# prepare_fields_transforms(opts) +# prepare_fields_transforms(opt) # except SystemExit as err: # print(err) # except IOError as err: -# if opts.skip_empty_level != 'error': +# if opt.skip_empty_level != 'error': # raise err # else: # print(f"Catched IOError: {err}") @@ -56,10 +56,10 @@ # # Remove the generated *pt files. # for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'): # os.remove(pt) -# if self.opts.save_data: +# if self.opt.save_data: # # Remove the generated data samples # sample_path = os.path.join( -# os.path.dirname(self.opts.save_data), +# os.path.dirname(self.opt.save_data), # CorpusName.SAMPLE) # if os.path.exists(sample_path): # for f in glob.glob(sample_path + '/*'): @@ -78,12 +78,12 @@ # # def test_method(self): # if param_setting: -# opts = copy.deepcopy(self.opts) +# opt = copy.deepcopy(self.opt) # for param, setting in param_setting: -# setattr(opts, param, setting) +# setattr(opt, param, setting) # else: -# opts = self.opts -# getattr(self, methodname)(opts) +# opt = self.opt +# getattr(self, methodname)(opt) # if param_setting: # name = 'test_' + methodname + "_" + "_".join( # str(param_setting).split()) diff --git a/mammoth/tests/test_embeddings.py b/onmt/tests/test_embeddings.py similarity index 98% rename from mammoth/tests/test_embeddings.py rename to onmt/tests/test_embeddings.py index 9abac740..a152838d 100644 --- a/mammoth/tests/test_embeddings.py +++ b/onmt/tests/test_embeddings.py @@ -1,12 +1,12 @@ import unittest -from mammoth.modules.embeddings import Embeddings +from onmt.modules.embeddings import Embeddings import itertools from copy import deepcopy import torch -from mammoth.tests.utils_for_tests import product_dict +from onmt.tests.utils_for_tests import product_dict class TestEmbeddings(unittest.TestCase): diff --git a/mammoth/tests/test_greedy_search.py b/onmt/tests/test_greedy_search.py similarity index 99% rename from mammoth/tests/test_greedy_search.py rename to onmt/tests/test_greedy_search.py index 6c718e69..b32e016f 100644 --- a/mammoth/tests/test_greedy_search.py +++ b/onmt/tests/test_greedy_search.py @@ -1,5 +1,5 @@ import unittest -from mammoth.translate.greedy_search import GreedySearch +from onmt.translate.greedy_search import GreedySearch import torch diff --git a/mammoth/tests/test_model.pt b/onmt/tests/test_model.pt similarity index 100% rename from mammoth/tests/test_model.pt rename to onmt/tests/test_model.pt diff --git a/mammoth/tests/test_model2.pt b/onmt/tests/test_model2.pt similarity index 100% rename from mammoth/tests/test_model2.pt rename to onmt/tests/test_model2.pt diff --git a/mammoth/tests/test_model_lm.pt b/onmt/tests/test_model_lm.pt similarity index 100% rename from mammoth/tests/test_model_lm.pt rename to onmt/tests/test_model_lm.pt diff --git a/mammoth/tests/test_models.py b/onmt/tests/test_models.py similarity index 68% rename from mammoth/tests/test_models.py rename to onmt/tests/test_models.py index 089ca796..1d70a3cc 100644 --- a/mammoth/tests/test_models.py +++ b/onmt/tests/test_models.py @@ -3,24 +3,24 @@ import torch -import mammoth -import mammoth.opts -from mammoth.model_builder import build_embeddings, build_encoder, build_decoder -from mammoth.inputters.vocab import Vocab, DEFAULT_SPECIALS -from mammoth.utils.parse import ArgumentParser +import onmt +import onmt.opts +from onmt.model_builder import build_embeddings, build_encoder, build_decoder +from onmt.inputters.vocab import Vocab, DEFAULT_SPECIALS +from onmt.utils.parse import ArgumentParser parser = ArgumentParser(description='train.py') -mammoth.opts.model_opts(parser) -mammoth.opts._add_train_general_opts(parser) +onmt.opts.model_opts(parser) +onmt.opts._add_train_general_opts(parser) # -data option is required, but not used in this test, so dummy. -opts = parser.parse_known_args(['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500'])[0] +opt = parser.parse_known_args(['-data', 'dummy', '-node_rank', '0'])[0] class TestModel(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestModel, self).__init__(*args, **kwargs) - self.opts = opts + self.opt = opt def get_field(self): return Vocab(None, items=[], tag='dummy', specials=list(DEFAULT_SPECIALS)) @@ -32,77 +32,82 @@ def get_batch(self, source_l=3, bsize=1): test_length = torch.ones(bsize).fill_(source_l).long() return test_src, test_tgt, test_length - def embeddings_forward(self, opts, source_l=3, bsize=1): + def embeddings_forward(self, opt, source_l=3, bsize=1): ''' Tests if the embeddings works as expected args: - opts: set of options + opt: set of options source_l: Length of generated input sentence bsize: Batchsize of generated input ''' word_field = self.get_field() - emb = build_embeddings(opts, word_field) + emb = build_embeddings(opt, word_field) test_src, _, __ = self.get_batch(source_l=source_l, bsize=bsize) - if opts.decoder_type == 'transformer': + if opt.decoder_type == 'transformer': input = torch.cat([test_src, test_src], 0) res = emb(input) - compare_to = torch.zeros(source_l * 2, bsize, opts.model_dim) + compare_to = torch.zeros(source_l * 2, bsize, opt.src_word_vec_size) else: res = emb(test_src) - compare_to = torch.zeros(source_l, bsize, opts.model_dim) + compare_to = torch.zeros(source_l, bsize, opt.src_word_vec_size) self.assertEqual(res.size(), compare_to.size()) - def encoder_forward(self, opts, source_l=3, bsize=1): + def encoder_forward(self, opt, source_l=3, bsize=1): ''' Tests if the encoder works as expected args: - opts: set of options + opt: set of options source_l: Length of generated input sentence bsize: Batchsize of generated input ''' + if opt.rnn_size > 0: + opt.enc_rnn_size = opt.rnn_size word_field = self.get_field() - embeddings = build_embeddings(opts, word_field) - enc = build_encoder(opts, embeddings) + embeddings = build_embeddings(opt, word_field) + enc = build_encoder(opt, embeddings) test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize) hidden_t, outputs, test_length = enc(test_src, test_length) # Initialize vectors to compare size with - test_hid = torch.zeros(self.opts.enc_layers, bsize, opts.model_dim) - test_out = torch.zeros(source_l, bsize, opts.model_dim) + test_hid = torch.zeros(self.opt.enc_layers, bsize, opt.enc_rnn_size) + test_out = torch.zeros(source_l, bsize, opt.dec_rnn_size) # Ensure correct sizes and types self.assertEqual(test_hid.size(), hidden_t[0].size(), hidden_t[1].size()) self.assertEqual(test_out.size(), outputs.size()) self.assertEqual(type(outputs), torch.Tensor) - def nmtmodel_forward(self, opts, source_l=3, bsize=1): + def nmtmodel_forward(self, opt, source_l=3, bsize=1): """ - Creates a nmtmodel with a custom opts function. + Creates a nmtmodel with a custom opt function. Forwards a testbatch and checks output size. Args: - opts: Namespace with options + opt: Namespace with options source_l: length of input sequence bsize: batchsize """ + if opt.rnn_size > 0: + opt.enc_rnn_size = opt.rnn_size + opt.dec_rnn_size = opt.rnn_size word_field = self.get_field() - embeddings = build_embeddings(opts, word_field) - enc = build_encoder(opts, embeddings) + embeddings = build_embeddings(opt, word_field) + enc = build_encoder(opt, embeddings) - embeddings = build_embeddings(opts, word_field, for_encoder=False) - dec = build_decoder(opts, embeddings) + embeddings = build_embeddings(opt, word_field, for_encoder=False) + dec = build_decoder(opt, embeddings) - model = mammoth.models.model.NMTModel(enc, dec) + model = onmt.models.model.NMTModel(enc, dec) test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize) outputs, attn = model(test_src, test_tgt, test_length) - outputsize = torch.zeros(source_l - 1, bsize, opts.model_dim) + outputsize = torch.zeros(source_l - 1, bsize, opt.dec_rnn_size) # Make sure that output has the correct size and type self.assertEqual(outputs.size(), outputsize.size()) self.assertEqual(type(outputs), torch.Tensor) @@ -118,12 +123,12 @@ def _add_test(param_setting, methodname): """ def test_method(self): - opts = copy.deepcopy(self.opts) + opt = copy.deepcopy(self.opt) if param_setting: for param, setting in param_setting: - setattr(opts, param, setting) - ArgumentParser.update_model_opts(opts) - getattr(self, methodname)(opts) + setattr(opt, param, setting) + ArgumentParser.update_model_opts(opt) + getattr(self, methodname)(opt) if param_setting: name = 'test_' + methodname + "_" + "_".join(str(param_setting).split()) @@ -136,7 +141,7 @@ def test_method(self): ''' TEST PARAMETERS ''' -opts.brnn = False +opt.brnn = False # FIXME: Most tests disabled: FoTraNMT only supports Transformer test_embeddings = [ @@ -151,7 +156,7 @@ def test_method(self): tests_encoder = [ # [], # [('encoder_type', 'mean')], - # [('encoder_type', 'transformer'), ('word_vec_size', 16), ('model_dim', 16)], + # [('encoder_type', 'transformer'), ('word_vec_size', 16), ('rnn_size', 16)], # [], ] @@ -168,14 +173,14 @@ def test_method(self): ('encoder_type', 'transformer'), ('src_word_vec_size', 16), ('tgt_word_vec_size', 16), - ('model_dim', 16), + ('rnn_size', 16), ], [ ('decoder_type', 'transformer'), ('encoder_type', 'transformer'), ('src_word_vec_size', 16), ('tgt_word_vec_size', 16), - ('model_dim', 16), + ('rnn_size', 16), ('position_encoding', True), ], # [('coverage_attn', True)], @@ -193,6 +198,10 @@ def test_method(self): # [], ] +if onmt.models.sru.check_sru_requirement(): + # """ Only do SRU test if requirment is safisfied. """ + # SRU doesn't support input_feed. + tests_nmtmodel.append([('rnn_type', 'SRU'), ('input_feed', 0)]) # ## FIXME: Broken in FoTraNMT # for p in tests_nmtmodel: diff --git a/mammoth/tests/test_models.sh b/onmt/tests/test_models.sh similarity index 100% rename from mammoth/tests/test_models.sh rename to onmt/tests/test_models.sh diff --git a/mammoth/tests/test_simple.py b/onmt/tests/test_simple.py similarity index 50% rename from mammoth/tests/test_simple.py rename to onmt/tests/test_simple.py index abdafbda..bd607e57 100644 --- a/mammoth/tests/test_simple.py +++ b/onmt/tests/test_simple.py @@ -1,6 +1,6 @@ -import mammoth +import onmt def test_load(): - mammoth + onmt pass diff --git a/onmt/tests/test_structured_attention.py b/onmt/tests/test_structured_attention.py new file mode 100644 index 00000000..543be5b8 --- /dev/null +++ b/onmt/tests/test_structured_attention.py @@ -0,0 +1,12 @@ +import unittest +from onmt.modules.structured_attention import MatrixTree + +import torch + + +class TestStructuredAttention(unittest.TestCase): + def test_matrix_tree_marg_pdfs_sum_to_1(self): + dtree = MatrixTree() + q = torch.rand(1, 5, 5) + marg = dtree.forward(q) + self.assertTrue(marg.sum(1).allclose(torch.tensor(1.0))) diff --git a/mammoth/tests/test_subword_marker.py b/onmt/tests/test_subword_marker.py similarity index 98% rename from mammoth/tests/test_subword_marker.py rename to onmt/tests/test_subword_marker.py index 63dd57c9..afa17fcf 100644 --- a/mammoth/tests/test_subword_marker.py +++ b/onmt/tests/test_subword_marker.py @@ -1,8 +1,8 @@ import unittest -from mammoth.transforms.denoising import word_start_finder -from mammoth.utils.alignment import subword_map_by_joiner, subword_map_by_spacer -from mammoth.constants import SubwordMarker +from onmt.transforms.denoising import word_start_finder +from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer +from onmt.constants import SubwordMarker class TestWordStartFinder(unittest.TestCase): diff --git a/mammoth/tests/test_task_distribution_strategy.py b/onmt/tests/test_task_distribution_strategy.py similarity index 86% rename from mammoth/tests/test_task_distribution_strategy.py rename to onmt/tests/test_task_distribution_strategy.py index 612d9140..aba90222 100644 --- a/mammoth/tests/test_task_distribution_strategy.py +++ b/onmt/tests/test_task_distribution_strategy.py @@ -1,11 +1,11 @@ import pytest from argparse import Namespace -from mammoth.distributed.tasks import WeightedSamplingTaskDistributionStrategy, RoundRobinTaskDistributionStrategy +from onmt.utils.distributed import WeightedSamplingTaskDistributionStrategy, RoundRobinTaskDistributionStrategy def test_weights_all_zero(): - opts = Namespace(data={ + opt = Namespace(data={ 'a': { 'weight': 0, 'introduce_at_training_step': 0, @@ -20,12 +20,12 @@ def test_weights_all_zero(): }, }) with pytest.raises(ValueError) as exc_info: - WeightedSamplingTaskDistributionStrategy.from_opts(['a', 'b'], opts) + WeightedSamplingTaskDistributionStrategy.from_opt(['a', 'b'], opt) assert 'Can not set "weight" of all corpora on a device to zero' in str(exc_info.value) def test_weights_all_postponed(): - opts = Namespace(data={ + opt = Namespace(data={ 'a': { 'weight': 1, 'introduce_at_training_step': 1, @@ -40,12 +40,12 @@ def test_weights_all_postponed(): }, }) with pytest.raises(ValueError) as exc_info: - WeightedSamplingTaskDistributionStrategy.from_opts(['a', 'b'], opts) + WeightedSamplingTaskDistributionStrategy.from_opt(['a', 'b'], opt) assert 'Can not set "introduce_at_training_step" of all corpora on a device to nonzero' in str(exc_info.value) def test_invalid_curriculum(): - opts = Namespace(data={ + opt = Namespace(data={ # 'a' disabled by weight 'a': { 'weight': 0, @@ -62,12 +62,12 @@ def test_invalid_curriculum(): }, }) with pytest.raises(ValueError) as exc_info: - WeightedSamplingTaskDistributionStrategy.from_opts(['a', 'b'], opts) + WeightedSamplingTaskDistributionStrategy.from_opt(['a', 'b'], opt) assert 'Invalid curriculum' in str(exc_info.value) def test_sampling_task_distribution_strategy(): - opts = Namespace(data={ + opt = Namespace(data={ # 'a' disabled by weight 'a': { 'weight': 0, @@ -89,7 +89,7 @@ def test_sampling_task_distribution_strategy(): 'introduce_at_training_step': 0, }, }) - strategy = WeightedSamplingTaskDistributionStrategy.from_opts(['a', 'b', 'c'], opts) + strategy = WeightedSamplingTaskDistributionStrategy.from_opt(['a', 'b', 'c'], opt) all_samples = [] n_samples = 10 n_batches = 1000 diff --git a/mammoth/tests/test_task_queue_manager.py b/onmt/tests/test_task_queue_manager.py similarity index 88% rename from mammoth/tests/test_task_queue_manager.py rename to onmt/tests/test_task_queue_manager.py index 17155080..efc75edd 100644 --- a/mammoth/tests/test_task_queue_manager.py +++ b/onmt/tests/test_task_queue_manager.py @@ -3,7 +3,7 @@ from collections import OrderedDict from unittest.mock import MagicMock -from mammoth.distributed import TaskQueueManager, WorldContext +from onmt.utils.distributed import TaskQueueManager, WorldContext def test_init_minimal(): @@ -20,9 +20,9 @@ def test_init_minimal(): 'train_c-d': {'path_src': 'dummy', 'path_tgt': 'dummy', 'src_tgt': 'c-d'}, } } - opts = Namespace(**opt_dict) - world_context = WorldContext.from_opts(opts) - task_queue_manager = TaskQueueManager.from_opts(opts, world_context) + opt = Namespace(**opt_dict) + world_context = WorldContext.from_opt(opt) + task_queue_manager = TaskQueueManager.from_opt(opt, world_context) assert world_context.is_gpu() assert world_context.is_distributed() assert len(task_queue_manager.tasks) == 2 @@ -95,15 +95,15 @@ def create_basic_task_queue_manager(): }, } } - opts = Namespace(**opt_dict) - world_context = WorldContext.from_opts(opts) - task_queue_manager = TaskQueueManager.from_opts(opts, world_context) - return task_queue_manager, opts + opt = Namespace(**opt_dict) + world_context = WorldContext.from_opt(opt) + task_queue_manager = TaskQueueManager.from_opt(opt, world_context) + return task_queue_manager, opt def test_init_basic(): - global_task_queue_manager, opts = create_basic_task_queue_manager() - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opts=opts) + global_task_queue_manager, opt = create_basic_task_queue_manager() + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opt=opt) world_context = task_queue_manager.world_context assert world_context.is_gpu() assert world_context.is_distributed() @@ -129,7 +129,7 @@ def __call__(self, sorted_global_ranks): self.group_idx += 1 return result - global_task_queue_manager, opts = create_basic_task_queue_manager() + global_task_queue_manager, opt = create_basic_task_queue_manager() all_groups = global_task_queue_manager.create_all_distributed_groups(new_group_func=MockGroup()) assert all_groups == { 'src_emb': OrderedDict({ @@ -157,8 +157,8 @@ def __call__(self, sorted_global_ranks): self.group_idx += 1 return result - global_task_queue_manager, opts = create_basic_task_queue_manager() - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opts=opts) + global_task_queue_manager, opt = create_basic_task_queue_manager() + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opt=opt) my_groups = task_queue_manager.get_distributed_groups(new_group_func=MockGroup()) assert my_groups == { 'encoder': OrderedDict({ @@ -196,10 +196,10 @@ def test_cpu_distributed_groups(): }, } } - opts = Namespace(**opt_dict) - world_context = WorldContext.from_opts(opts) - global_task_queue_manager = TaskQueueManager.from_opts(opts, world_context) - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opts=opts) + opt = Namespace(**opt_dict) + world_context = WorldContext.from_opt(opt) + global_task_queue_manager = TaskQueueManager.from_opt(opt, world_context) + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opt=opt) new_group_func = MagicMock().new_group_func my_groups = task_queue_manager.get_distributed_groups(new_group_func=new_group_func) # No groups should be created when running on CPU @@ -248,10 +248,10 @@ def test_distributed_groups_no_encoder_group(): }, } } - opts = Namespace(**opt_dict) - world_context = WorldContext.from_opts(opts) - global_task_queue_manager = TaskQueueManager.from_opts(opts, world_context) - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opts=opts) + opt = Namespace(**opt_dict) + world_context = WorldContext.from_opt(opt) + global_task_queue_manager = TaskQueueManager.from_opt(opt, world_context) + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opt=opt) new_group_func = MagicMock().new_group_func my_groups = task_queue_manager.get_distributed_groups(new_group_func=new_group_func) # No groups should be created: @@ -269,20 +269,20 @@ def test_distributed_groups_no_encoder_group(): # (side, lang): f'{side} {lang}' for (side, lang) in # [('src', 'a'), ('src', 'c'), ('src', 'e'), ('tgt', 'b'), ('tgt', 'd')] # } -# global_task_queue_manager, opts = create_basic_task_queue_manager() -# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opts=opts) +# global_task_queue_manager, opt = create_basic_task_queue_manager() +# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opt=opt) # fields = task_queue_manager.get_fields('src', mock_fields) # assert fields == [('src', 'a', None, 'src a')] # fields = task_queue_manager.get_fields('tgt', mock_fields) # assert fields == [('tgt', 'b', None, 'tgt b')] # -# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opts=opts) +# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opt=opt) # fields = task_queue_manager.get_fields('src', mock_fields) # assert fields == [('src', 'c', None, 'src c'), ('src', 'a', 'x', 'src a')] # fields = task_queue_manager.get_fields('tgt', mock_fields) # assert fields == [('tgt', 'd', None, 'tgt d')] # -# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=1, local_rank=0, opts=opts) +# task_queue_manager = global_task_queue_manager.global_to_local(node_rank=1, local_rank=0, opt=opt) # fields = task_queue_manager.get_fields('src', mock_fields) # assert fields == [('src', 'e', None, 'src e')] # fields = task_queue_manager.get_fields('tgt', mock_fields) @@ -290,8 +290,8 @@ def test_distributed_groups_no_encoder_group(): def test_basic_getters(): - global_task_queue_manager, opts = create_basic_task_queue_manager() - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opts=opts) + global_task_queue_manager, opt = create_basic_task_queue_manager() + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=0, opt=opt) encoders = list(task_queue_manager.get_encoders(0)) assert encoders == ['x'] decoders = list(task_queue_manager.get_decoders(0)) @@ -303,7 +303,7 @@ def test_basic_getters(): generators = list(task_queue_manager.get_generators()) assert generators == ['b'] - task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opts=opts) + task_queue_manager = global_task_queue_manager.global_to_local(node_rank=0, local_rank=1, opt=opt) encoders = list(task_queue_manager.get_encoders(0)) assert encoders == ['xx', 'x'] decoders = list(task_queue_manager.get_decoders(0)) diff --git a/mammoth/tests/test_text_dataset.py b/onmt/tests/test_text_dataset.py similarity index 99% rename from mammoth/tests/test_text_dataset.py rename to onmt/tests/test_text_dataset.py index 9d6f3b56..0fffe0ca 100644 --- a/mammoth/tests/test_text_dataset.py +++ b/onmt/tests/test_text_dataset.py @@ -7,7 +7,7 @@ # # # from torchtext.legacy.data import Field # -# from mammoth.tests.utils_for_tests import product_dict +# from onmt.tests.utils_for_tests import product_dict # # # class TestTextMultiField(unittest.TestCase): diff --git a/mammoth/tests/test_transform.py b/onmt/tests/test_transform.py similarity index 91% rename from mammoth/tests/test_transform.py rename to onmt/tests/test_transform.py index e53ec3e5..32e25f83 100644 --- a/mammoth/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -5,13 +5,13 @@ import yaml import math from argparse import Namespace -from mammoth.transforms import ( +from onmt.transforms import ( get_transforms_cls, get_specials, make_transforms, TransformPipe, ) -from mammoth.transforms.denoising import BARTNoising +from onmt.transforms.denoising import BARTNoising class TestTransform(unittest.TestCase): @@ -31,13 +31,13 @@ def test_transform_register(self): def test_vocab_required_transform(self): transforms_cls = get_transforms_cls(["denoising", "switchout"]) - opts = Namespace(seed=-1, switchout_temperature=1.0) + opt = Namespace(seed=-1, switchout_temperature=1.0) # transforms that require vocab will not create if not provide vocab - transforms = make_transforms(opts, transforms_cls, vocabs=None, task=None) + transforms = make_transforms(opt, transforms_cls, vocabs=None, task=None) self.assertEqual(len(transforms), 0) with self.assertRaises(ValueError): - transforms_cls["switchout"](opts).warm_up(vocabs=None) - transforms_cls["denoising"](opts).warm_up(vocabs=None) + transforms_cls["switchout"](opt).warm_up(vocabs=None) + transforms_cls["denoising"](opt).warm_up(vocabs=None) def test_transform_specials(self): transforms_cls = get_transforms_cls(["prefix"]) @@ -52,8 +52,8 @@ def test_transform_specials(self): tgt_prefix: "⦅_pf_tgt⦆" """ ) - opts = Namespace(tasks=corpora) - specials = get_specials(opts, transforms_cls) + opt = Namespace(data=corpora) + specials = get_specials(opt, transforms_cls) specials_expected = {"src": {"⦅_pf_src⦆"}, "tgt": {"⦅_pf_tgt⦆"}} self.assertEqual(specials, specials_expected) @@ -71,13 +71,13 @@ def test_transform_pipe(self): tgt_prefix: "⦅_pf_tgt⦆" """ ) - opts = Namespace(tasks=corpora, seed=-1) - prefix_transform = prefix_cls(opts) + opt = Namespace(data=corpora, seed=-1) + prefix_transform = prefix_cls(opt) prefix_transform.warm_up() # 2. Init second transform in the pipe filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"] - opts = Namespace(src_seq_length=4, tgt_seq_length=4) - filter_transform = filter_cls(opts) + opt = Namespace(src_seq_length=4, tgt_seq_length=4) + filter_transform = filter_cls(opt) # 3. Sequential combine them into a transform pipe transform_pipe = TransformPipe.build_from([prefix_transform, filter_transform]) ex = { @@ -110,8 +110,8 @@ def test_prefix(self): tgt_prefix: "⦅_pf_tgt⦆" """ ) - opts = Namespace(tasks=corpora, seed=-1) - prefix_transform = prefix_cls(opts) + opt = Namespace(data=corpora, seed=-1) + prefix_transform = prefix_cls(opt) prefix_transform.warm_up() self.assertIn("trainset", prefix_transform.prefix_dict) @@ -128,8 +128,8 @@ def test_prefix(self): def test_filter_too_long(self): filter_cls = get_transforms_cls(["filtertoolong"])["filtertoolong"] - opts = Namespace(src_seq_length=100, tgt_seq_length=100) - filter_transform = filter_cls(opts) + opt = Namespace(src_seq_length=100, tgt_seq_length=100) + filter_transform = filter_cls(opt) # filter_transform.warm_up() ex_in = { "src": ["Hello", "world", "."], @@ -162,9 +162,9 @@ def setUpClass(cls): def test_bpe(self): bpe_cls = get_transforms_cls(["bpe"])["bpe"] - opts = Namespace(**self.base_opts) - bpe_cls._validate_options(opts) - bpe_transform = bpe_cls(opts) + opt = Namespace(**self.base_opts) + bpe_cls._validate_options(opt) + bpe_transform = bpe_cls(opt) bpe_transform.warm_up() ex = { "src": ["Hello", "world", "."], @@ -212,9 +212,9 @@ def test_sentencepiece(self): base_opt = copy.copy(self.base_opts) base_opt["src_subword_model"] = "data/sample.sp.model" base_opt["tgt_subword_model"] = "data/sample.sp.model" - opts = Namespace(**base_opt) - sp_cls._validate_options(opts) - sp_transform = sp_cls(opts) + opt = Namespace(**base_opt) + sp_cls._validate_options(opt) + sp_transform = sp_cls(opt) sp_transform.warm_up() ex = { "src": ["Hello", "world", "."], @@ -245,9 +245,9 @@ def test_pyonmttok_bpe(self): onmt_args = "{'mode': 'space', 'joiner_annotate': True}" base_opt["src_onmttok_kwargs"] = onmt_args base_opt["tgt_onmttok_kwargs"] = onmt_args - opts = Namespace(**base_opt) - onmttok_cls._validate_options(opts) - onmttok_transform = onmttok_cls(opts) + opt = Namespace(**base_opt) + onmttok_cls._validate_options(opt) + onmttok_transform = onmttok_cls(opt) onmttok_transform.warm_up() ex = { "src": ["Hello", "world", "."], @@ -270,9 +270,9 @@ def test_pyonmttok_sp(self): onmt_args = "{'mode': 'none', 'spacer_annotate': True}" base_opt["src_onmttok_kwargs"] = onmt_args base_opt["tgt_onmttok_kwargs"] = onmt_args - opts = Namespace(**base_opt) - onmttok_cls._validate_options(opts) - onmttok_transform = onmttok_cls(opts) + opt = Namespace(**base_opt) + onmttok_cls._validate_options(opt) + onmttok_transform = onmttok_cls(opt) onmttok_transform.warm_up() ex = { "src": ["Hello", "world", "."], @@ -289,8 +289,8 @@ def test_pyonmttok_sp(self): class TestSamplingTransform(unittest.TestCase): def test_tokendrop(self): tokendrop_cls = get_transforms_cls(["tokendrop"])["tokendrop"] - opts = Namespace(seed=3434, tokendrop_temperature=0.1) - tokendrop_transform = tokendrop_cls(opts) + opt = Namespace(seed=3434, tokendrop_temperature=0.1) + tokendrop_transform = tokendrop_cls(opt) tokendrop_transform.warm_up() ex = { "src": ["Hello", ",", "world", "."], @@ -305,8 +305,8 @@ def test_tokendrop(self): def test_tokenmask(self): tokenmask_cls = get_transforms_cls(["tokenmask"])["tokenmask"] - opts = Namespace(seed=3434, tokenmask_temperature=0.1) - tokenmask_transform = tokenmask_cls(opts) + opt = Namespace(seed=3434, tokenmask_temperature=0.1) + tokenmask_transform = tokenmask_cls(opt) tokenmask_transform.warm_up() ex = { "src": ["Hello", ",", "world", "."], @@ -321,8 +321,8 @@ def test_tokenmask(self): def test_switchout(self): switchout_cls = get_transforms_cls(["switchout"])["switchout"] - opts = Namespace(seed=3434, switchout_temperature=0.1) - switchout_transform = switchout_cls(opts) + opt = Namespace(seed=3434, switchout_temperature=0.1) + switchout_transform = switchout_cls(opt) with self.assertRaises(ValueError): # require vocabs to warm_up switchout_transform.warm_up(vocabs=None) @@ -518,16 +518,16 @@ def test_span_infilling(self): def test_vocab_required_transform(self): transforms_cls = get_transforms_cls(["denoising"]) - opts = Namespace(random_ratio=1, denoising_objective='mass') + opt = Namespace(random_ratio=1, denoising_objective='mass') with self.assertRaises(ValueError): - make_transforms(opts, transforms_cls, vocabs=None, task=None) + make_transforms(opt, transforms_cls, vocabs=None, task=None) class TestFeaturesTransform(unittest.TestCase): def test_inferfeats(self): inferfeats_cls = get_transforms_cls(["inferfeats"])["inferfeats"] - opts = Namespace(reversible_tokenization="joiner", prior_tokenization=False) - inferfeats_transform = inferfeats_cls(opts) + opt = Namespace(reversible_tokenization="joiner", prior_tokenization=False) + inferfeats_transform = inferfeats_cls(opt) ex_in = { "src": [ diff --git a/mammoth/tests/test_translation_server.py b/onmt/tests/test_translation_server.py similarity index 85% rename from mammoth/tests/test_translation_server.py rename to onmt/tests/test_translation_server.py index c639f3b5..07489a25 100644 --- a/mammoth/tests/test_translation_server.py +++ b/onmt/tests/test_translation_server.py @@ -1,12 +1,12 @@ import unittest -from mammoth.translate.translation_server import ServerModel, TranslationServer +from onmt.translate.translation_server import ServerModel, TranslationServer import os from textwrap import dedent import torch -from mammoth.translate.translator import Translator +from onmt.translate.translator import Translator TEST_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -16,9 +16,9 @@ class TestServerModel(unittest.TestCase): @unittest.skip('Broken in FoTraNMT') # FIXME def test_deferred_loading_model_and_unload(self): model_id = 0 - opts = {"models": ["test_model.pt"]} + opt = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opts, model_id, model_root=model_root, load=False) + sm = ServerModel(opt, model_id, model_root=model_root, load=False) self.assertFalse(sm.loaded) sm.load() self.assertTrue(sm.loaded) @@ -29,9 +29,9 @@ def test_deferred_loading_model_and_unload(self): @unittest.skip('Broken in FoTraNMT') # FIXME def test_load_model_on_init_and_unload(self): model_id = 0 - opts = {"models": ["test_model.pt"]} + opt = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opts, model_id, model_root=model_root, load=True) + sm = ServerModel(opt, model_id, model_root=model_root, load=True) self.assertTrue(sm.loaded) self.assertIsInstance(sm.translator, Translator) sm.unload() @@ -40,18 +40,18 @@ def test_load_model_on_init_and_unload(self): @unittest.skip('Broken in FoTraNMT') # FIXME def test_tokenizing_with_no_tokenizer_fails(self): model_id = 0 - opts = {"models": ["test_model.pt"]} + opt = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opts, model_id, model_root=model_root, load=True) + sm = ServerModel(opt, model_id, model_root=model_root, load=True) with self.assertRaises(ValueError): sm.tokenize("hello world") @unittest.skip('Broken in FoTraNMT') # FIXME def test_detokenizing_with_no_tokenizer_fails(self): model_id = 0 - opts = {"models": ["test_model.pt"]} + opt = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opts, model_id, model_root=model_root, load=True) + sm = ServerModel(opt, model_id, model_root=model_root, load=True) with self.assertRaises(ValueError): sm.detokenize("hello world") @@ -60,9 +60,9 @@ def test_detokenizing_with_no_tokenizer_fails(self): def test_moving_to_gpu_and_back(self): torch.cuda.set_device(torch.device("cuda", 0)) model_id = 0 - opts = {"models": ["test_model.pt"]} + opt = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opts, model_id, model_root=model_root, load=True) + sm = ServerModel(opt, model_id, model_root=model_root, load=True) for p in sm.translator.model.parameters(): self.assertEqual(p.device.type, "cpu") sm.to_gpu() @@ -76,9 +76,9 @@ def test_moving_to_gpu_and_back(self): def test_initialize_on_gpu_and_move_back(self): torch.cuda.set_device(torch.device("cuda", 0)) model_id = 0 - opts = {"models": ["test_model.pt"], "gpu": 0} + opt = {"models": ["test_model.pt"], "gpu": 0} model_root = TEST_DIR - sm = ServerModel(opts, model_id, model_root=model_root, load=True) + sm = ServerModel(opt, model_id, model_root=model_root, load=True) for p in sm.translator.model.parameters(): self.assertEqual(p.device.type, "cuda") self.assertEqual(p.device.index, 0) @@ -95,9 +95,9 @@ def test_initialize_on_gpu_and_move_back(self): def test_initialize_on_nonzero_gpu_and_back(self): torch.cuda.set_device(torch.device("cuda", 1)) model_id = 0 - opts = {"models": ["test_model.pt"], "gpu": 1} + opt = {"models": ["test_model.pt"], "gpu": 1} model_root = TEST_DIR - sm = ServerModel(opts, model_id, model_root=model_root, load=True) + sm = ServerModel(opt, model_id, model_root=model_root, load=True) for p in sm.translator.model.parameters(): self.assertEqual(p.device.type, "cuda") self.assertEqual(p.device.index, 1) @@ -112,9 +112,9 @@ def test_initialize_on_nonzero_gpu_and_back(self): @unittest.skip('Broken in FoTraNMT') # FIXME def test_run(self): model_id = 0 - opts = {"models": ["test_model.pt"]} + opt = {"models": ["test_model.pt"]} model_root = TEST_DIR - sm = ServerModel(opts, model_id, model_root=model_root, load=True) + sm = ServerModel(opt, model_id, model_root=model_root, load=True) inp = [{"src": "hello how are you today"}, {"src": "good morning to you ."}] results, scores, n_best, time, aligns = sm.run(inp) self.assertIsInstance(results, list) @@ -160,7 +160,7 @@ def write(self, cfg): "timeout": -1, "on_timeout": "to_cpu", "load": false, - "opts": { + "opt": { "beam_size": 5 } } @@ -188,7 +188,7 @@ def test_start_without_initial_loading(self): "timeout": -1, "on_timeout": "to_cpu", "load": true, - "opts": { + "opt": { "beam_size": 5 } } @@ -217,7 +217,7 @@ def test_start_with_initial_loading(self): "timeout": -1, "on_timeout": "to_cpu", "load": true, - "opts": { + "opt": { "beam_size": 5 } }, @@ -227,7 +227,7 @@ def test_start_with_initial_loading(self): "timeout": -1, "on_timeout": "to_cpu", "load": false, - "opts": { + "opt": { "beam_size": 5 } } diff --git a/mammoth/tests/test_translator.py b/onmt/tests/test_translator.py similarity index 96% rename from mammoth/tests/test_translator.py rename to onmt/tests/test_translator.py index 8107e2e6..78ffe60b 100644 --- a/mammoth/tests/test_translator.py +++ b/onmt/tests/test_translator.py @@ -1,5 +1,5 @@ import unittest -from mammoth.translate import GeneratorLM +from onmt.translate import GeneratorLM import torch diff --git a/mammoth/tests/utils_for_tests.py b/onmt/tests/utils_for_tests.py similarity index 100% rename from mammoth/tests/utils_for_tests.py rename to onmt/tests/utils_for_tests.py diff --git a/mammoth/train_single.py b/onmt/train_single.py similarity index 74% rename from mammoth/train_single.py rename to onmt/train_single.py index f036f0d9..b396c5f4 100644 --- a/mammoth/train_single.py +++ b/onmt/train_single.py @@ -3,54 +3,53 @@ import torch import time -from mammoth.model_builder import build_model -from mammoth.utils.optimizers import Optimizer -from mammoth.utils.misc import set_random_seed -from mammoth.trainer import build_trainer -from mammoth.models import build_model_saver -from mammoth.utils.logging import init_logger, logger -from mammoth.utils.parse import ArgumentParser +from onmt.model_builder import build_model +from onmt.utils.optimizers import Optimizer +from onmt.utils.misc import set_random_seed +from onmt.trainer import build_trainer +from onmt.models import build_model_saver +from onmt.utils.logging import init_logger, logger +from onmt.utils.parse import ArgumentParser -from mammoth.distributed import broadcast_tensors -from mammoth.inputters import DynamicDatasetIter -from mammoth.transforms import get_transforms_cls +from onmt.utils.distributed import broadcast_tensors +from onmt.inputters import DynamicDatasetIter +from onmt.transforms import get_transforms_cls -def configure_process(opts, device_id): +def configure_process(opt, device_id): logger.info("logger set device {} ".format(device_id)) if device_id >= 0: torch.cuda.set_device(device_id) - set_random_seed(opts.seed, device_id >= 0) + set_random_seed(opt.seed, device_id >= 0) -def _get_model_opts(opts, checkpoint=None): - """Get `model_opts` to build model, may load from `checkpoint` if any.""" +def _get_model_opts(opt, checkpoint=None): + """Get `model_opt` to build model, may load from `checkpoint` if any.""" if checkpoint is not None: - model_opts = ArgumentParser.ckpt_model_opts(checkpoint["opts"]) - ArgumentParser.update_model_opts(model_opts) - ArgumentParser.validate_model_opts(model_opts) - if opts.tensorboard_log_dir == model_opts.tensorboard_log_dir and \ - hasattr(model_opts, 'tensorboard_log_dir_dated'): + model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) + ArgumentParser.update_model_opts(model_opt) + ArgumentParser.validate_model_opts(model_opt) + if opt.tensorboard_log_dir == model_opt.tensorboard_log_dir and hasattr(model_opt, 'tensorboard_log_dir_dated'): # ensure tensorboard output is written in the directory # of previous checkpoints - opts.tensorboard_log_dir_dated = model_opts.tensorboard_log_dir_dated + opt.tensorboard_log_dir_dated = model_opt.tensorboard_log_dir_dated # Override checkpoint's update_embeddings as it defaults to false - model_opts.update_vocab = opts.update_vocab + model_opt.update_vocab = opt.update_vocab else: - model_opts = opts - return model_opts + model_opt = opt + return model_opt -def _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager): +def _build_valid_iter(opt, vocabs_dict, transforms_cls, task_queue_manager): """Build iterator used for validation.""" - if not any(opts.tasks[corpus_id].get('path_valid_src', False) for corpus_id in opts.tasks.keys()): + if not any(opt.data[corpus_id].get('path_valid_src', False) for corpus_id in opt.data.keys()): return None logger.info("creating validation iterator") valid_iter = DynamicDatasetIter.from_opts( task_queue_manager=task_queue_manager, transforms_cls=transforms_cls, vocabs_dict=vocabs_dict, - opts=opts, + opts=opt, is_train=False, ) return valid_iter @@ -108,7 +107,7 @@ def init_distributed(model, task_queue_manager): def main( - opts, + opt, vocabs_dict, device_context, error_queue=None, @@ -117,26 +116,26 @@ def main( task_queue_manager=None, ): """Start training on `device_id`.""" - # NOTE: It's important that ``opts`` has been validated and updated + # NOTE: It's important that ``opt`` has been validated and updated # at this point. # N.B: task_queue_manager is already local - init_logger(opts.log_file, gpu_id=device_context.id) + init_logger(opt.log_file, gpu_id=device_context.id) if device_context.is_distributed(): sleep_s = device_context.local_rank * 3 logger.warning(f'sleeping {sleep_s}s to alleviate ROCm deadlock') time.sleep(sleep_s) - configure_process(opts, device_context.local_rank) + configure_process(opt, device_context.local_rank) gpu_rank_t = torch.distributed.get_rank() logger.info("RANK GPU FROM TORCH %s", str(gpu_rank_t)) - transforms_cls = get_transforms_cls(opts._all_transform) + transforms_cls = get_transforms_cls(opt._all_transform) checkpoint = None - model_opts = _get_model_opts(opts, checkpoint=checkpoint) + model_opt = _get_model_opts(opt, checkpoint=checkpoint) # Build model. - model, generators_md = build_model(model_opts, opts, vocabs_dict, task_queue_manager, checkpoint) + model, generators_md = build_model(model_opt, opt, vocabs_dict, task_queue_manager, checkpoint) logger.info("{} - Init model".format(device_context.id)) if device_context.is_distributed(): @@ -150,19 +149,19 @@ def main( # Build optimizer. logger.info("{} - Build optimizer".format(device_context.id)) - optim = Optimizer.from_opts( + optim = Optimizer.from_opt( model, - opts, + opt, task_queue_manager=task_queue_manager, checkpoint=checkpoint, ) # Build model saver - model_saver = build_model_saver(model_opts, opts, model, vocabs_dict, optim, device_context) + model_saver = build_model_saver(model_opt, opt, model, vocabs_dict, optim, device_context) logger.info("{} - Build trainer".format(device_context.id)) trainer = build_trainer( - opts, + opt, device_context, model, vocabs_dict, @@ -178,7 +177,7 @@ def main( task_queue_manager=task_queue_manager, transforms_cls=transforms_cls, vocabs_dict=vocabs_dict, - opts=opts, + opts=opt, is_train=True, ) # TODO: check that IterOnDevice is unnecessary here; corpora should be already on device @@ -199,15 +198,15 @@ def _train_iter(): train_iter = _train_iter() # train_iter = iter_on_device(train_iter, device_context) logger.info("Device {} - Valid iter".format(device_context.id)) - valid_iter = _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager) + valid_iter = _build_valid_iter(opt, vocabs_dict, transforms_cls, task_queue_manager) - if len(opts.gpu_ranks): + if len(opt.gpu_ranks): if device_context.is_master(): - logger.info('Starting training on GPU: %s' % opts.gpu_ranks) + logger.info('Starting training on GPU: %s' % opt.gpu_ranks) else: logger.info('Starting training on CPU, could be very slow') - train_steps = opts.train_steps - if opts.single_pass and train_steps > 0: + train_steps = opt.train_steps + if opt.single_pass and train_steps > 0: if device_context.is_master(): logger.warning("Option single_pass is enabled, ignoring train_steps.") train_steps = 0 @@ -215,9 +214,9 @@ def _train_iter(): trainer.train( train_iter, train_steps, - save_checkpoint_steps=opts.save_checkpoint_steps, + save_checkpoint_steps=opt.save_checkpoint_steps, valid_iter=valid_iter, - valid_steps=opts.valid_steps, + valid_steps=opt.valid_steps, device_context=device_context, ) diff --git a/mammoth/trainer.py b/onmt/trainer.py similarity index 86% rename from mammoth/trainer.py rename to onmt/trainer.py index 4f6527e9..0b76a3a3 100644 --- a/mammoth/trainer.py +++ b/onmt/trainer.py @@ -10,14 +10,14 @@ """ -import mammoth.distributed +import onmt.utils import torch import torch.distributed import torch.nn as nn import traceback from itertools import islice -from mammoth.utils.logging import logger +from onmt.utils.logging import logger def iter_on_device(iterator, device_context): @@ -30,7 +30,7 @@ def iter_on_device(iterator, device_context): def build_trainer( - opts, + opt, device_context, model, vocabs_dict, @@ -40,16 +40,16 @@ def build_trainer( generators_md=None, ): """ - Simplify `Trainer` creation based on user `opts`s* + Simplify `Trainer` creation based on user `opt`s* Args: - opts (:obj:`Namespace`): user options (usually from argument parsing) - model (:obj:`mammoth.models.NMTModel`): the model to train + opt (:obj:`Namespace`): user options (usually from argument parsing) + model (:obj:`onmt.models.NMTModel`): the model to train vocabs_dict (dict): dict of vocabs - optim (:obj:`mammoth.utils.Optimizer`): optimizer used during training + optim (:obj:`onmt.utils.Optimizer`): optimizer used during training data_type (str): string describing the type of data e.g. "text" - model_saver(:obj:`mammoth.models.ModelSaverBase`): the utility object + model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object used to save the model """ @@ -61,32 +61,32 @@ def build_trainer( generator = generators_md[f'generator_{lang}'] train_loss_md.add_module( f'trainloss{lang}', - mammoth.utils.loss.build_loss_compute(model, tgt_vocab, opts, train=True, generator=generator), + onmt.utils.loss.build_loss_compute(model, tgt_vocab, opt, train=True, generator=generator), ) valid_loss_md.add_module( f'valloss{lang}', - mammoth.utils.loss.build_loss_compute(model, tgt_vocab, opts, train=False, generator=generator), + onmt.utils.loss.build_loss_compute(model, tgt_vocab, opt, train=False, generator=generator), ) - trunc_size = opts.truncated_decoder # Badly named... - shard_size = opts.max_generator_batches if opts.model_dtype == 'fp32' else 0 - norm_method = opts.normalization - accum_count = opts.accum_count - accum_steps = opts.accum_steps - average_decay = opts.average_decay - average_every = opts.average_every - dropout = opts.dropout - dropout_steps = opts.dropout_steps - gpu_verbose_level = opts.gpu_verbose_level + trunc_size = opt.truncated_decoder # Badly named... + shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0 + norm_method = opt.normalization + accum_count = opt.accum_count + accum_steps = opt.accum_steps + average_decay = opt.average_decay + average_every = opt.average_every + dropout = opt.dropout + dropout_steps = opt.dropout_steps + gpu_verbose_level = opt.gpu_verbose_level earlystopper = ( - mammoth.utils.EarlyStopping(opts.early_stopping, scorers=mammoth.utils.scorers_from_opts(opts)) - if opts.early_stopping > 0 + onmt.utils.EarlyStopping(opt.early_stopping, scorers=onmt.utils.scorers_from_opts(opt)) + if opt.early_stopping > 0 else None ) - report_manager = mammoth.utils.build_report_manager(opts, device_context.node_rank, device_context.local_rank) - trainer = mammoth.Trainer( + report_manager = onmt.utils.build_report_manager(opt, device_context.node_rank, device_context.local_rank) + trainer = onmt.Trainer( model, train_loss_md, valid_loss_md, @@ -99,16 +99,16 @@ def build_trainer( device_context=device_context, gpu_verbose_level=gpu_verbose_level, report_manager=report_manager, - with_align=True if opts.lambda_align > 0 else False, + with_align=True if opt.lambda_align > 0 else False, model_saver=model_saver, average_decay=average_decay, average_every=average_every, - model_dtype=opts.model_dtype, + model_dtype=opt.model_dtype, earlystopper=earlystopper, dropout=dropout, dropout_steps=dropout_steps, task_queue_manager=task_queue_manager, - report_stats_from_parameters=opts.report_stats_from_parameters, + report_stats_from_parameters=opt.report_stats_from_parameters, ) return trainer @@ -118,13 +118,13 @@ class Trainer(object): Class that controls the training process. Args: - model(:py:class:`mammoth.models.model.NMTModel`): translation model + model(:py:class:`onmt.models.model.NMTModel`): translation model to train - train_loss(:obj:`mammoth.utils.loss.LossComputeBase`): + train_loss(:obj:`onmt.utils.loss.LossComputeBase`): training loss computation - valid_loss(:obj:`mammoth.utils.loss.LossComputeBase`): + valid_loss(:obj:`onmt.utils.loss.LossComputeBase`): training loss computation - optim(:obj:`mammoth.utils.optimizers.Optimizer`): + optim(:obj:`onmt.utils.optimizers.Optimizer`): the optimizer responsible for update trunc_size(int): length of truncated back propagation through time shard_size(int): compute loss in shards of this size for efficiency @@ -132,9 +132,9 @@ class Trainer(object): norm_method(string): normalization methods: [sents|tokens] accum_count(list): accumulate gradients this many times. accum_steps(list): steps for accum gradients changes. - report_manager(:obj:`mammoth.utils.ReportMgrBase`): + report_manager(:obj:`onmt.utils.ReportMgrBase`): the object that creates reports, or None - model_saver(:obj:`mammoth.models.ModelSaverBase`): the saver is + model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is used to save a checkpoint. Thus nothing will be saved if this parameter is None """ @@ -260,8 +260,8 @@ def train( else: logger.info('Start training loop and validate every %d steps...', valid_steps) - total_stats = mammoth.utils.Statistics() - report_stats = mammoth.utils.Statistics() + total_stats = onmt.utils.Statistics() + report_stats = onmt.utils.Statistics() self._start_report_manager(start_time=total_stats.start_time) self.optim.zero_grad() @@ -289,7 +289,7 @@ def train( in self.model.encoder.get_submodule(layer_stack_index, encoder_id).named_parameters() if 'embeddings' not in name and 'adapter' not in name ] - mammoth.distributed.only_ready_reduce_and_rescale_grads(params, group=group) + onmt.utils.distributed.only_ready_reduce_and_rescale_grads(params, group=group) for (layer_stack_index, decoder_id), (_, group) in self.my_decoder_groups.items(): params = [ @@ -297,17 +297,17 @@ def train( in self.model.decoder.get_submodule(layer_stack_index, decoder_id).named_parameters() if 'embeddings' not in name and 'adapter' not in name ] - mammoth.distributed.only_ready_reduce_and_rescale_grads(params, group=group) + onmt.utils.distributed.only_ready_reduce_and_rescale_grads(params, group=group) for (src_lang,), (_, group) in self.my_src_emb_groups.items(): embs = self.model.encoder.embeddings[f'embeddings_{src_lang}'] - mammoth.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) + onmt.utils.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) for (tgt_lang,), (_, group) in self.my_tgt_emb_groups.items(): embs = self.model.decoder.embeddings[f'embeddings_{tgt_lang}'] - mammoth.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) + onmt.utils.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) - mammoth.distributed.only_ready_reduce_and_rescale_grads( + onmt.utils.distributed.only_ready_reduce_and_rescale_grads( self.model.generator[f'generator_{tgt_lang}'].named_parameters(), group=group ) @@ -316,18 +316,18 @@ def train( adapter = self.model.encoder.get_submodule(layer_stack_index, encoder_id).get_adapter( adapter_group, sub_id ) - mammoth.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) + onmt.utils.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) for adapter_id, (_, group) in self.my_decoder_adapter_groups.items(): layer_stack_index, decoder_id, adapter_group, sub_id = adapter_id adapter = self.model.decoder.get_submodule(layer_stack_index, decoder_id).get_adapter( adapter_group, sub_id ) - mammoth.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) + onmt.utils.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) # a group is not specified: reduce across all devices if device_context.is_distributed(): - mammoth.distributed.only_ready_reduce_and_rescale_grads( + onmt.utils.distributed.only_ready_reduce_and_rescale_grads( self.model.attention_bridge.named_parameters() ) @@ -420,11 +420,11 @@ def validate(self, valid_iter, moving_average=None, task=None): # Tasks need not define validation paths: hence, a device need not contain # any validation path. This would cause statistics equals to 0 word seen, # which would then cause a zero devision when normalizing PPL per words. - stats = None # mammoth.utils.Statistics() + stats = None # onmt.utils.Statistics() for batch, metadata, _ in valid_iter: if stats is None: - stats = mammoth.utils.Statistics() + stats = onmt.utils.Statistics() src, src_lengths = batch.src if isinstance(batch.src, tuple) else (batch.src, None) tgt = batch.tgt @@ -542,14 +542,14 @@ def _maybe_gather_stats(self, stat): Gather statistics in multi-processes cases Args: - stat(:obj:mammoth.utils.Statistics): a Statistics object to gather + stat(:obj:onmt.utils.Statistics): a Statistics object to gather or None (it returns None in this case) Returns: stat: the updated (or unchanged) stat object """ if stat is not None and self.device_context.is_distributed(): - return mammoth.utils.Statistics.all_gather_stats(stat) + return onmt.utils.Statistics.all_gather_stats(stat) return stat def _maybe_update_stats_from_parameters(self, report_stats, named_parameters): @@ -559,7 +559,7 @@ def _maybe_update_stats_from_parameters(self, report_stats, named_parameters): def _maybe_report_training(self, step, num_steps, learning_rate, report_stats): """ Simple function to report training stats (if report_manager is set) - see `mammoth.utils.ReportManagerBase.report_training` for doc + see `onmt.utils.ReportManagerBase.report_training` for doc """ if self.report_manager is not None: return self.report_manager.report_training( @@ -574,7 +574,7 @@ def _maybe_report_training(self, step, num_steps, learning_rate, report_stats): def _report_step(self, learning_rate, step, train_stats=None, valid_stats=None): """ Simple function to report stats (if report_manager is set) - see `mammoth.utils.ReportManagerBase.report_step` for doc + see `onmt.utils.ReportManagerBase.report_step` for doc """ if self.report_manager is not None: return self.report_manager.report_step( diff --git a/mammoth/transforms/__init__.py b/onmt/transforms/__init__.py similarity index 95% rename from mammoth/transforms/__init__.py rename to onmt/transforms/__init__.py index b585e216..673f8383 100644 --- a/mammoth/transforms/__init__.py +++ b/onmt/transforms/__init__.py @@ -49,4 +49,4 @@ def register_transfrom_cls(cls): path = os.path.join(transform_dir, file) if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): file_name = file[: file.find('.py')] if file.endswith('.py') else file - module = importlib.import_module('mammoth.transforms.' + file_name) + module = importlib.import_module('onmt.transforms.' + file_name) diff --git a/mammoth/transforms/denoising.py b/onmt/transforms/denoising.py similarity index 99% rename from mammoth/transforms/denoising.py rename to onmt/transforms/denoising.py index f1fc9693..36fa8c44 100644 --- a/mammoth/transforms/denoising.py +++ b/onmt/transforms/denoising.py @@ -4,8 +4,8 @@ import torch from typing import Sequence, Callable -from mammoth.constants import DefaultTokens, SubwordMarker -from mammoth.transforms import register_transform +from onmt.constants import DefaultTokens, SubwordMarker +from onmt.transforms import register_transform from .transform import Transform diff --git a/mammoth/transforms/features.py b/onmt/transforms/features.py similarity index 94% rename from mammoth/transforms/features.py rename to onmt/transforms/features.py index 762cdcd4..a6fd06c4 100644 --- a/mammoth/transforms/features.py +++ b/onmt/transforms/features.py @@ -1,7 +1,7 @@ -from mammoth.utils.logging import logger -from mammoth.transforms import register_transform +from onmt.utils.logging import logger +from onmt.transforms import register_transform from .transform import Transform -from mammoth.utils.alignment import subword_map_by_joiner, subword_map_by_spacer +from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer import re from collections import defaultdict diff --git a/mammoth/transforms/misc.py b/onmt/transforms/misc.py similarity index 96% rename from mammoth/transforms/misc.py rename to onmt/transforms/misc.py index b8c1e8b1..a7b8e1a0 100644 --- a/mammoth/transforms/misc.py +++ b/onmt/transforms/misc.py @@ -1,5 +1,5 @@ -from mammoth.utils.logging import logger -from mammoth.transforms import register_transform +from onmt.utils.logging import logger +from onmt.transforms import register_transform from .transform import Transform, ObservableStats @@ -72,7 +72,7 @@ def _get_prefix(corpus): def get_prefix_dict(cls, opts): """Get all needed prefix correspond to corpus in `opts`.""" prefix_dict = {} - for c_name, corpus in opts.tasks.items(): + for c_name, corpus in opts.data.items(): prefix = cls._get_prefix(corpus) if prefix is not None: logger.info(f"Get prefix for {c_name}: {prefix}") diff --git a/mammoth/transforms/sampling.py b/onmt/transforms/sampling.py similarity index 98% rename from mammoth/transforms/sampling.py rename to onmt/transforms/sampling.py index e2b55182..c0fadea9 100644 --- a/mammoth/transforms/sampling.py +++ b/onmt/transforms/sampling.py @@ -1,8 +1,8 @@ """Transforms relate to hamming distance sampling.""" import random import numpy as np -from mammoth.constants import DefaultTokens -from mammoth.transforms import register_transform +from onmt.constants import DefaultTokens +from onmt.transforms import register_transform from .transform import Transform, ObservableStats diff --git a/mammoth/transforms/tokenize.py b/onmt/transforms/tokenize.py similarity index 99% rename from mammoth/transforms/tokenize.py rename to onmt/transforms/tokenize.py index 5c6e283a..bf4470d3 100644 --- a/mammoth/transforms/tokenize.py +++ b/onmt/transforms/tokenize.py @@ -1,6 +1,6 @@ """Transforms relate to tokenization/subword.""" -from mammoth.utils.logging import logger -from mammoth.transforms import register_transform +from onmt.utils.logging import logger +from onmt.transforms import register_transform from .transform import Transform, ObservableStats diff --git a/mammoth/transforms/transform.py b/onmt/transforms/transform.py similarity index 99% rename from mammoth/transforms/transform.py rename to onmt/transforms/transform.py index e553f6a2..6238d3ae 100644 --- a/mammoth/transforms/transform.py +++ b/onmt/transforms/transform.py @@ -1,7 +1,7 @@ """Base Transform class and relate utils.""" import torch -from mammoth.utils.logging import logger -from mammoth.utils.misc import check_path +from onmt.utils.logging import logger +from onmt.utils.misc import check_path class Transform(object): diff --git a/onmt/translate/__init__.py b/onmt/translate/__init__.py new file mode 100644 index 00000000..21901092 --- /dev/null +++ b/onmt/translate/__init__.py @@ -0,0 +1,25 @@ +""" Modules for translation """ +from onmt.translate.translator import Translator, GeneratorLM +from onmt.translate.translation import Translation, TranslationBuilder +from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer +from onmt.translate.beam_search import BeamSearchLM +from onmt.translate.decode_strategy import DecodeStrategy +from onmt.translate.greedy_search import GreedySearch, GreedySearchLM +from onmt.translate.penalties import PenaltyBuilder +from onmt.translate.translation_server import TranslationServer, ServerModelError + +__all__ = [ + 'Translator', + 'Translation', + 'BeamSearch', + 'GNMTGlobalScorer', + 'TranslationBuilder', + 'PenaltyBuilder', + 'TranslationServer', + 'ServerModelError', + "DecodeStrategy", + "GreedySearch", + "GreedySearchLM", + "BeamSearchLM", + "GeneratorLM", +] diff --git a/mammoth/translate/beam_search.py b/onmt/translate/beam_search.py similarity index 98% rename from mammoth/translate/beam_search.py rename to onmt/translate/beam_search.py index c5741367..cb60c298 100644 --- a/mammoth/translate/beam_search.py +++ b/onmt/translate/beam_search.py @@ -1,6 +1,6 @@ import torch -from mammoth.translate import penalties -from mammoth.translate.decode_strategy import DecodeStrategy +from onmt.translate import penalties +from onmt.translate.decode_strategy import DecodeStrategy import warnings @@ -22,7 +22,7 @@ class BeamSearchBase(DecodeStrategy): unk (int): See base. n_best (int): Don't stop until at least this many beams have reached EOS. - global_scorer (mammoth.translate.GNMTGlobalScorer): Scorer instance. + global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. min_length (int): See base. max_length (int): See base. return_attention (bool): See base. @@ -410,8 +410,8 @@ class GNMTGlobalScorer(object): """ @classmethod - def from_opts(cls, opts): - return cls(opts.alpha, opts.beta, opts.length_penalty, opts.coverage_penalty) + def from_opt(cls, opt): + return cls(opt.alpha, opt.beta, opt.length_penalty, opt.coverage_penalty) def __init__(self, alpha, beta, length_penalty, coverage_penalty): self._validate(alpha, beta, length_penalty, coverage_penalty) diff --git a/mammoth/translate/decode_strategy.py b/onmt/translate/decode_strategy.py similarity index 99% rename from mammoth/translate/decode_strategy.py rename to onmt/translate/decode_strategy.py index cabf7539..0fd86906 100644 --- a/mammoth/translate/decode_strategy.py +++ b/onmt/translate/decode_strategy.py @@ -1,7 +1,7 @@ import torch from copy import deepcopy -from mammoth.utils.misc import tile +from onmt.utils.misc import tile class DecodeStrategy(object): diff --git a/mammoth/translate/greedy_search.py b/onmt/translate/greedy_search.py similarity index 96% rename from mammoth/translate/greedy_search.py rename to onmt/translate/greedy_search.py index 91251b32..1631d379 100644 --- a/mammoth/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F -from mammoth.translate.decode_strategy import DecodeStrategy +from onmt.translate.decode_strategy import DecodeStrategy def sample_topp(logits, keep_topp): @@ -98,7 +98,7 @@ class GreedySearch(DecodeStrategy): eos (int): See base. unk (int): See base. batch_size (int): See base. - global_scorer (mammoth.translate.GNMTGlobalScorer): Scorer instance. + global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. min_length (int): See base. max_length (int): See base. ban_unk_token (Boolean): See base. @@ -107,11 +107,11 @@ class GreedySearch(DecodeStrategy): return_attention (bool): See base. max_length (int): See base. sampling_temp (float): See - :func:`~mammoth.translate.greedy_search.sample_with_temperature()`. + :func:`~onmt.translate.greedy_search.sample_with_temperature()`. keep_topk (int): See - :func:`~mammoth.translate.greedy_search.sample_with_temperature()`. + :func:`~onmt.translate.greedy_search.sample_with_temperature()`. keep_topp (float): See - :func:`~mammoth.translate.greedy_search.sample_with_temperature()`. + :func:`~onmt.translate.greedy_search.sample_with_temperature()`. beam_size (int): Number of beams to use. """ diff --git a/mammoth/translate/penalties.py b/onmt/translate/penalties.py similarity index 100% rename from mammoth/translate/penalties.py rename to onmt/translate/penalties.py diff --git a/mammoth/translate/process_zh.py b/onmt/translate/process_zh.py similarity index 100% rename from mammoth/translate/process_zh.py rename to onmt/translate/process_zh.py diff --git a/mammoth/translate/translation.py b/onmt/translate/translation.py similarity index 96% rename from mammoth/translate/translation.py rename to onmt/translate/translation.py index 3b985d00..8d2aebb9 100644 --- a/mammoth/translate/translation.py +++ b/onmt/translate/translation.py @@ -1,7 +1,7 @@ """ Translation main class """ import os -from mammoth.constants import DefaultTokens -from mammoth.utils.alignment import build_align_pharaoh +from onmt.constants import DefaultTokens +from onmt.utils.alignment import build_align_pharaoh # FIXME @@ -14,8 +14,8 @@ class TranslationBuilder(object): Problem in Neural Machine Translation" :cite:`Luong2015b` Args: - data (mammoth.inputters.ParallelCorpus): Data. - vocabs (dict[str, mammoth.inputters.Vocab]): data vocabs + data (onmt.inputters.ParallelCorpus): Data. + vocabs (dict[str, onmt.inputters.Vocab]): data vocabs n_best (int): number of translations produced replace_unk (bool): replace unknown words using attention has_tgt (bool): will the batch have gold targets diff --git a/mammoth/translate/translation_server.py b/onmt/translate/translation_server.py similarity index 89% rename from mammoth/translate/translation_server.py rename to onmt/translate/translation_server.py index be2e4043..7184637d 100644 --- a/mammoth/translate/translation_server.py +++ b/onmt/translate/translation_server.py @@ -10,18 +10,18 @@ import traceback import importlib import torch -import mammoth.opts +import onmt.opts from itertools import islice, zip_longest from copy import deepcopy -from mammoth.constants import DefaultTokens -from mammoth.utils.logging import init_logger -from mammoth.utils.misc import set_random_seed -from mammoth.utils.misc import check_model_config -from mammoth.utils.alignment import to_word_align -from mammoth.utils.parse import ArgumentParser -from mammoth.translate.translator import build_translator +from onmt.constants import DefaultTokens +from onmt.utils.logging import init_logger +from onmt.utils.misc import set_random_seed +from onmt.utils.misc import check_model_config +from onmt.utils.alignment import to_word_align +from onmt.utils.parse import ArgumentParser +from onmt.translate.translator import build_translator def critical(func): @@ -78,7 +78,7 @@ class ServerModelError(Exception): class CTranslate2Translator(object): """ This class wraps the ctranslate2.Translator object to - reproduce the mammoth.translate.translator API. + reproduce the onmt.translate.translator API. """ def __init__(self, model_path, ct2_translator_args, ct2_translate_batch_args, target_prefix=False, preload=False): @@ -95,7 +95,7 @@ def __init__(self, model_path, ct2_translator_args, ct2_translate_batch_args, ta self.translator.unload_model(to_cpu=True) @staticmethod - def convert_onmt_to_ct2_opts(ct2_translator_args, ct2_translate_batch_args, opts): + def convert_onmt_to_ct2_opts(ct2_translator_args, ct2_translate_batch_args, opt): def setdefault_if_exists_must_match(obj, name, value): if name in obj: assert value == obj[name], ( @@ -115,18 +115,18 @@ def setdefault_if_exists_must_match(obj, name, value): ct2_translator_args.setdefault(name, value) onmt_for_translator = { - "device": "cuda" if opts.cuda else "cpu", - "device_index": opts.gpu if opts.cuda else 0, + "device": "cuda" if opt.cuda else "cpu", + "device_index": opt.gpu if opt.cuda else 0, } for name, value in onmt_for_translator.items(): setdefault_if_exists_must_match(ct2_translator_args, name, value) onmt_for_translate_batch_enforce = { - "beam_size": opts.beam_size, - "max_batch_size": opts.batch_size, - "num_hypotheses": opts.n_best, - "max_decoding_length": opts.max_length, - "min_decoding_length": opts.min_length, + "beam_size": opt.beam_size, + "max_batch_size": opt.batch_size, + "num_hypotheses": opt.n_best, + "max_decoding_length": opt.max_length, + "min_decoding_length": opt.min_length, } for name, value in onmt_for_translate_batch_enforce.items(): setdefault_if_exists_must_match(ct2_translate_batch_args, name, value) @@ -191,32 +191,32 @@ def start(self, config_file): } kwargs = {k: v for (k, v) in kwargs.items() if v is not None} model_id = conf.get("id", None) - opts = conf["opts"] - opts["models"] = conf["models"] - self.preload_model(opts, model_id=model_id, **kwargs) + opt = conf["opt"] + opt["models"] = conf["models"] + self.preload_model(opt, model_id=model_id, **kwargs) - def clone_model(self, model_id, opts, timeout=-1): + def clone_model(self, model_id, opt, timeout=-1): """Clone a model `model_id`. - Different options may be passed. If `opts` is None, it will use the + Different options may be passed. If `opt` is None, it will use the same set of options """ if model_id in self.models: - if opts is None: - opts = self.models[model_id].user_opt - opts["models"] = self.models[model_id].opts.models - return self.load_model(opts, timeout) + if opt is None: + opt = self.models[model_id].user_opt + opt["models"] = self.models[model_id].opt.models + return self.load_model(opt, timeout) else: raise ServerModelError("No such model '%s'" % str(model_id)) - def load_model(self, opts, model_id=None, **model_kwargs): + def load_model(self, opt, model_id=None, **model_kwargs): """Load a model given a set of options""" - model_id = self.preload_model(opts, model_id=model_id, **model_kwargs) + model_id = self.preload_model(opt, model_id=model_id, **model_kwargs) load_time = self.models[model_id].load_time return model_id, load_time - def preload_model(self, opts, model_id=None, **model_kwargs): + def preload_model(self, opt, model_id=None, **model_kwargs): """Preloading the model: updating internal datastructure It will effectively load the model if `load` is set @@ -230,7 +230,7 @@ def preload_model(self, opts, model_id=None, **model_kwargs): model_id += 1 self.next_id = model_id + 1 print("Pre-loading model %d" % model_id) - model = ServerModel(opts, model_id, **model_kwargs) + model = ServerModel(opt, model_id, **model_kwargs) self.models[model_id] = model return model_id @@ -274,7 +274,7 @@ class ServerModel(object): """Wrap a model with server functionality. Args: - opts (dict): Options for the Translator + opt (dict): Options for the Translator model_id (int): Model ID preprocess_opt (list): Options for preprocess processus or None tokenizer_opt (dict): Options for the tokenizer or None @@ -292,7 +292,7 @@ class ServerModel(object): def __init__( self, - opts, + opt, model_id, preprocess_opt=None, tokenizer_opt=None, @@ -307,7 +307,7 @@ def __init__( ct2_translate_batch_args=None, ): self.model_root = model_root - self.opts = self.parse_opt(opts) + self.opt = self.parse_opt(opt) self.custom_opt = custom_opt self.model_id = model_id @@ -322,20 +322,20 @@ def __init__( self.ct2_translate_batch_args = ct2_translate_batch_args self.unload_timer = None - self.user_opt = opts + self.user_opt = opt self.tokenizers = None - if len(self.opts.log_file) > 0: - log_file = os.path.join(model_root, self.opts.log_file) + if len(self.opt.log_file) > 0: + log_file = os.path.join(model_root, self.opt.log_file) else: log_file = None - self.logger = init_logger(log_file=log_file, log_file_level=self.opts.log_file_level, rotate=True) + self.logger = init_logger(log_file=log_file, log_file_level=self.opt.log_file_level, rotate=True) self.loading_lock = threading.Event() self.loading_lock.set() self.running_lock = threading.Semaphore(value=1) - set_random_seed(self.opts.seed, self.opts.cuda) + set_random_seed(self.opt.seed, self.opt.cuda) if self.preprocess_opt is not None: self.logger.info("Loading preprocessor") @@ -370,28 +370,28 @@ def __init__( self.load(preload=True) self.stop_unload_timer() - def parse_opt(self, opts): - """Parse the option set passed by the user using `mammoth.opts` + def parse_opt(self, opt): + """Parse the option set passed by the user using `onmt.opts` Args: - opts (dict): Options passed by the user + opt (dict): Options passed by the user Returns: - opts (argparse.Namespace): full set of options for the Translator + opt (argparse.Namespace): full set of options for the Translator """ prec_argv = sys.argv sys.argv = sys.argv[:1] parser = ArgumentParser() - mammoth.opts.translate_opts(parser) + onmt.opts.translate_opts(parser) - models = opts['models'] + models = opt['models'] if not isinstance(models, (list, tuple)): models = [models] - opts['models'] = [os.path.join(self.model_root, model) for model in models] - opts['src'] = "dummy_src" + opt['models'] = [os.path.join(self.model_root, model) for model in models] + opt['src'] = "dummy_src" - for (k, v) in opts.items(): + for (k, v) in opt.items(): if k == 'models': sys.argv += ['-model'] sys.argv += [str(model) for model in v] @@ -400,12 +400,12 @@ def parse_opt(self, opts): else: sys.argv += ['-%s' % k, str(v)] - opts = parser.parse_args() - ArgumentParser.validate_translate_opts(opts) - opts.cuda = opts.gpu > -1 + opt = parser.parse_args() + ArgumentParser.validate_translate_opts(opt) + opt.cuda = opt.gpu > -1 sys.argv = prec_argv - return opts + return opt @property def loaded(self): @@ -421,18 +421,18 @@ def load(self, preload=False): try: if self.ct2_model is not None: CTranslate2Translator.convert_onmt_to_ct2_opts( - self.ct2_translator_args, self.ct2_translate_batch_args, self.opts + self.ct2_translator_args, self.ct2_translate_batch_args, self.opt ) self.translator = CTranslate2Translator( self.ct2_model, ct2_translator_args=self.ct2_translator_args, ct2_translate_batch_args=self.ct2_translate_batch_args, - target_prefix=self.opts.tgt_prefix, + target_prefix=self.opt.tgt_prefix, preload=preload, ) else: self.translator = build_translator( - self.opts, report_score=False, out_file=codecs.open(os.devnull, "w", "utf-8") + self.opt, report_score=False, out_file=codecs.open(os.devnull, "w", "utf-8") ) except RuntimeError as e: raise ServerModelError("Runtime Error: %s" % str(e)) @@ -470,7 +470,7 @@ def run(self, inputs): if not self.loaded: self.load() timer.tick(name="load") - elif self.opts.cuda: + elif self.opt.cuda: self.to_gpu() timer.tick(name="to_gpu") @@ -517,14 +517,14 @@ def run(self, inputs): scores, predictions = self.translator.translate( texts_to_translate, tgt=texts_ref, - batch_size=len(texts_to_translate) if self.opts.batch_size == 0 else self.opts.batch_size, + batch_size=len(texts_to_translate) if self.opt.batch_size == 0 else self.opt.batch_size, ) except (RuntimeError, Exception) as e: err = "Error: %s" % str(e) self.logger.error(err) self.logger.error("repr(text_to_translate): " + repr(texts_to_translate)) self.logger.error("model: #%s" % self.model_id) - self.logger.error("model opts: " + str(self.opts.__dict__)) + self.logger.error("model opt: " + str(self.opt.__dict__)) self.logger.error(traceback.format_exc()) raise ServerModelError(err) @@ -541,7 +541,7 @@ def run(self, inputs): def flatten_list(_list): return sum(_list, []) - tiled_texts = [t for t in texts_to_translate for _ in range(self.opts.n_best)] + tiled_texts = [t for t in texts_to_translate for _ in range(self.opt.n_best)] results = flatten_list(predictions) def maybe_item(x): @@ -556,24 +556,24 @@ def maybe_item(x): # build back results with empty texts for i in empty_indices: - j = i * self.opts.n_best - results = results[:j] + [""] * self.opts.n_best + results[j:] - aligns = aligns[:j] + [None] * self.opts.n_best + aligns[j:] - scores = scores[:j] + [0] * self.opts.n_best + scores[j:] + j = i * self.opt.n_best + results = results[:j] + [""] * self.opt.n_best + results[j:] + aligns = aligns[:j] + [None] * self.opt.n_best + aligns[j:] + scores = scores[:j] + [0] * self.opt.n_best + scores[j:] rebuilt_segs, scores, aligns = self.rebuild_seg_packages( - all_preprocessed, results, scores, aligns, self.opts.n_best + all_preprocessed, results, scores, aligns, self.opt.n_best ) results = [self.maybe_postprocess(seg) for seg in rebuilt_segs] - head_spaces = [h for h in head_spaces for i in range(self.opts.n_best)] - tail_spaces = [h for h in tail_spaces for i in range(self.opts.n_best)] + head_spaces = [h for h in head_spaces for i in range(self.opt.n_best)] + tail_spaces = [h for h in tail_spaces for i in range(self.opt.n_best)] results = ["".join(items) for items in zip(head_spaces, results, tail_spaces)] self.logger.info("Translation Results: %d", len(results)) - return results, scores, self.opts.n_best, timer.times, aligns + return results, scores, self.opt.n_best, timer.times, aligns def rebuild_seg_packages(self, all_preprocessed, results, scores, aligns, n_best): """ @@ -618,7 +618,7 @@ def do_timeout(self): def unload(self): self.logger.info("Unloading model %d" % self.model_id) del self.translator - if self.opts.cuda: + if self.opt.cuda: torch.cuda.empty_cache() self.stop_unload_timer() self.unload_timer = None @@ -639,7 +639,7 @@ def to_dict(self): hide_opt = ["models", "src"] d = { "model_id": self.model_id, - "opts": {k: self.user_opt[k] for k in self.user_opt.keys() if k not in hide_opt}, + "opt": {k: self.user_opt[k] for k in self.user_opt.keys() if k not in hide_opt}, "models": self.user_opt["models"], "loaded": self.loaded, "timeout": self.timeout, @@ -655,7 +655,7 @@ def to_cpu(self): self.translator.to_cpu() else: self.translator.model.cpu() - if self.opts.cuda: + if self.opt.cuda: torch.cuda.empty_cache() def to_gpu(self): @@ -663,7 +663,7 @@ def to_gpu(self): if type(self.translator) == CTranslate2Translator: self.translator.to_gpu() else: - torch.cuda.set_device(self.opts.gpu) + torch.cuda.set_device(self.opt.gpu) self.translator.model.cuda() def maybe_preprocess(self, sequence): @@ -785,7 +785,7 @@ def maybe_detokenize_with_align(self, sequence, src, side='tgt'): sorted or None if no alignment in output. """ align = None - if self.opts.report_align: + if self.opt.report_align: # output contain alignment sequence, align = sequence.split(DefaultTokens.ALIGNMENT_SEPARATOR) if align != '': diff --git a/mammoth/translate/translator.py b/onmt/translate/translator.py similarity index 88% rename from mammoth/translate/translator.py rename to onmt/translate/translator.py index 8388ec22..44c9092c 100644 --- a/mammoth/translate/translator.py +++ b/onmt/translate/translator.py @@ -8,58 +8,58 @@ import torch -import mammoth.model_builder -import mammoth.modules.decoder_ensemble -# from mammoth.inputters.text_dataset import InferenceDataIterator -from mammoth.translate.beam_search import BeamSearch, BeamSearchLM -from mammoth.translate.greedy_search import GreedySearch, GreedySearchLM -from mammoth.utils.misc import tile, set_random_seed, report_matrix -from mammoth.utils.alignment import extract_alignment, build_align_pharaoh -from mammoth.constants import ModelTask, DefaultTokens -from mammoth.inputters.dataset import ParallelCorpus -from mammoth.inputters.dataloader import build_dataloader - - -def build_translator(opts, task, report_score=True, logger=None, out_file=None): +import onmt.model_builder +import onmt.decoders.ensemble +# from onmt.inputters.text_dataset import InferenceDataIterator +from onmt.translate.beam_search import BeamSearch, BeamSearchLM +from onmt.translate.greedy_search import GreedySearch, GreedySearchLM +from onmt.utils.misc import tile, set_random_seed, report_matrix +from onmt.utils.alignment import extract_alignment, build_align_pharaoh +from onmt.modules.copy_generator import collapse_copy_scores +from onmt.constants import ModelTask, DefaultTokens +from onmt.inputters.dataset import ParallelCorpus +from onmt.inputters.dataloader import build_dataloader + + +def build_translator(opt, task, report_score=True, logger=None, out_file=None): if out_file is None: - outdir = os.path.dirname(opts.output) + outdir = os.path.dirname(opt.output) if outdir and not os.path.isdir(outdir): # FIXME use warnings instead logger.info('WARNING: output file directory does not exist... creating it.') - os.makedirs(os.path.dirname(opts.output), exist_ok=True) - out_file = codecs.open(opts.output, "w+", "utf-8") + os.makedirs(os.path.dirname(opt.output), exist_ok=True) + out_file = codecs.open(opt.output, "w+", "utf-8") load_test_model = ( - mammoth.modules.decoder_ensemble.load_test_model if len(opts.models) > 3 - else mammoth.model_builder.load_test_multitask_model + onmt.decoders.ensemble.load_test_model if len(opt.models) > 3 else onmt.model_builder.load_test_multitask_model ) if logger: logger.info(str(task)) - vocabs, model, model_opts = load_test_model(opts) + vocabs, model, model_opt = load_test_model(opt) - scorer = mammoth.translate.GNMTGlobalScorer.from_opts(opts) + scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt) - if model_opts.model_task == ModelTask.LANGUAGE_MODEL: - translator = GeneratorLM.from_opts( + if model_opt.model_task == ModelTask.LANGUAGE_MODEL: + translator = GeneratorLM.from_opt( model, vocabs, - opts, - model_opts, + opt, + model_opt, global_scorer=scorer, out_file=out_file, - report_align=opts.report_align, + report_align=opt.report_align, report_score=report_score, logger=logger, ) else: - translator = Translator.from_opts( + translator = Translator.from_opt( model, vocabs, - opts, - model_opts, + opt, + model_opt, global_scorer=scorer, out_file=out_file, - report_align=opts.report_align, + report_align=opt.report_align, report_score=report_score, logger=logger, task=task, @@ -90,36 +90,36 @@ class Inference(object): """Translate a batch of sentences with a saved model. Args: - model (mammoth.modules.NMTModel): NMT model to use for translation - vocabs (dict[str, mammoth.inputters.Vocab]): A dict + model (onmt.modules.NMTModel): NMT model to use for translation + vocabs (dict[str, onmt.inputters.Vocab]): A dict mapping each side to its Vocab. src_file_path (str): Source file to read. tgt_reader (src): Target file, if necessary. gpu (int): GPU device. Set to negative for no GPU. n_best (int): How many beams to wait for. min_length (int): See - :class:`mammoth.translate.decode_strategy.DecodeStrategy`. + :class:`onmt.translate.decode_strategy.DecodeStrategy`. max_length (int): See - :class:`mammoth.translate.decode_strategy.DecodeStrategy`. + :class:`onmt.translate.decode_strategy.DecodeStrategy`. beam_size (int): Number of beams. random_sampling_topk (int): See - :class:`mammoth.translate.greedy_search.GreedySearch`. + :class:`onmt.translate.greedy_search.GreedySearch`. random_sampling_temp (float): See - :class:`mammoth.translate.greedy_search.GreedySearch`. + :class:`onmt.translate.greedy_search.GreedySearch`. stepwise_penalty (bool): Whether coverage penalty is applied every step or not. dump_beam (bool): Debugging option. block_ngram_repeat (int): See - :class:`mammoth.translate.decode_strategy.DecodeStrategy`. + :class:`onmt.translate.decode_strategy.DecodeStrategy`. ignore_when_blocking (set or frozenset): See - :class:`mammoth.translate.decode_strategy.DecodeStrategy`. + :class:`onmt.translate.decode_strategy.DecodeStrategy`. replace_unk (bool): Replace unknown token. tgt_prefix (bool): Force the predictions begin with provided -tgt. data_type (str): Source data type. verbose (bool): Print/log every translation. report_time (bool): Print/log total time/frequency. copy_attn (bool): Use copy attention. - global_scorer (mammoth.translate.GNMTGlobalScorer): Translation + global_scorer (onmt.translate.GNMTGlobalScorer): Translation scoring/reranking object. out_file (TextIO or codecs.StreamReaderWriter): Output file. report_score (bool) : Whether to report scores @@ -236,12 +236,12 @@ def __init__( set_random_seed(seed, self._use_cuda) @classmethod - def from_opts( + def from_opt( cls, model, vocabs, - opts, - model_opts, + opt, + model_opt, global_scorer=None, out_file=None, report_align=False, @@ -252,13 +252,13 @@ def from_opts( """Alternate constructor. Args: - model (mammoth.modules.NMTModel): See :func:`__init__()`. - vocabs (dict[str, mammoth.inputters.Vocab]): See + model (onmt.modules.NMTModel): See :func:`__init__()`. + vocabs (dict[str, onmt.inputters.Vocab]): See :func:`__init__()`. - opts (argparse.Namespace): Command line options - model_opts (argparse.Namespace): Command line options saved with + opt (argparse.Namespace): Command line options + model_opt (argparse.Namespace): Command line options saved with the model checkpoint. - global_scorer (mammoth.translate.GNMTGlobalScorer): See + global_scorer (onmt.translate.GNMTGlobalScorer): See :func:`__init__()`.. out_file (TextIO or codecs.StreamReaderWriter): See :func:`__init__()`. @@ -268,40 +268,40 @@ def from_opts( """ assert task is not None # TODO: maybe add dynamic part - cls.validate_task(model_opts.model_task) + cls.validate_task(model_opt.model_task) return cls( model, vocabs, - opts.src, - tgt_file_path=opts.tgt, - gpu=opts.gpu, - n_best=opts.n_best, - min_length=opts.min_length, - max_length=opts.max_length, - ratio=opts.ratio, - beam_size=opts.beam_size, - random_sampling_topk=opts.random_sampling_topk, - random_sampling_topp=opts.random_sampling_topp, - random_sampling_temp=opts.random_sampling_temp, - stepwise_penalty=opts.stepwise_penalty, - dump_beam=opts.dump_beam, - block_ngram_repeat=opts.block_ngram_repeat, - ignore_when_blocking=set(opts.ignore_when_blocking), - replace_unk=opts.replace_unk, - ban_unk_token=opts.ban_unk_token, - tgt_prefix=opts.tgt_prefix, - phrase_table=opts.phrase_table, - data_type=opts.data_type, - verbose=opts.verbose, - report_time=opts.report_time, - copy_attn=model_opts.copy_attn, + opt.src, + tgt_file_path=opt.tgt, + gpu=opt.gpu, + n_best=opt.n_best, + min_length=opt.min_length, + max_length=opt.max_length, + ratio=opt.ratio, + beam_size=opt.beam_size, + random_sampling_topk=opt.random_sampling_topk, + random_sampling_topp=opt.random_sampling_topp, + random_sampling_temp=opt.random_sampling_temp, + stepwise_penalty=opt.stepwise_penalty, + dump_beam=opt.dump_beam, + block_ngram_repeat=opt.block_ngram_repeat, + ignore_when_blocking=set(opt.ignore_when_blocking), + replace_unk=opt.replace_unk, + ban_unk_token=opt.ban_unk_token, + tgt_prefix=opt.tgt_prefix, + phrase_table=opt.phrase_table, + data_type=opt.data_type, + verbose=opt.verbose, + report_time=opt.report_time, + copy_attn=model_opt.copy_attn, global_scorer=global_scorer, out_file=out_file, report_align=report_align, report_score=report_score, logger=logger, - seed=opts.seed, + seed=opt.seed, task=task, ) @@ -499,7 +499,7 @@ def _translate( # ) # data_iter = None - xlation_builder = mammoth.translate.TranslationBuilder( + xlation_builder = onmt.translate.TranslationBuilder( corpus, self.vocabs, self.n_best, @@ -669,14 +669,39 @@ def _decode_and_generate( ) # Generator forward. - if "std" in dec_attn: - attn = dec_attn["std"] + if not self.copy_attn: + if "std" in dec_attn: + attn = dec_attn["std"] + else: + attn = None + log_probs = self.model.generator[f"generator_{self.task.tgt_lang}"](dec_out.squeeze(0)) + # returns [(batch_size x beam_size) , vocab ] when 1 step + # or [ tgt_len, batch_size, vocab ] when full sentence else: - attn = None - log_probs = self.model.generator[f"generator_{self.task.tgt_lang}"](dec_out.squeeze(0)) - # returns [(batch_size x beam_size) , vocab ] when 1 step - # or [ tgt_len, batch_size, vocab ] when full sentence - + attn = dec_attn["copy"] + scores = self.model.generator[f"generator_{self.task.tgt_lang}"]( + dec_out.view(-1, dec_out.size(2)), + attn.view(-1, attn.size(2)), + src_map, + ) + # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab] + if batch_offset is None: + scores = scores.view(-1, batch.batch_size, scores.size(-1)) + scores = scores.transpose(0, 1).contiguous() + else: + scores = scores.view(-1, self.beam_size, scores.size(-1)) + scores = collapse_copy_scores( + scores, + batch, + self._tgt_vocab, + src_vocabs, + batch_dim=0, + batch_offset=batch_offset, + ) + scores = scores.view(decoder_in.size(0), -1, scores.size(-1)) + log_probs = scores.squeeze(0).log() + # returns [(batch_size x beam_size) , vocab ] when 1 step + # or [ tgt_len, batch_size, vocab ] when full sentence return log_probs, attn def translate_batch(self, batch, src_vocabs, attn_debug): diff --git a/onmt/utils/__init__.py b/onmt/utils/__init__.py new file mode 100644 index 00000000..e835836e --- /dev/null +++ b/onmt/utils/__init__.py @@ -0,0 +1,23 @@ +"""Module defining various utilities.""" +from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed +from onmt.utils.alignment import make_batch_align_matrix +from onmt.utils.report_manager import ReportMgr, build_report_manager +from onmt.utils.statistics import Statistics +from onmt.utils.optimizers import MultipleOptimizer, Optimizer, AdaFactorFairSeq +from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts + +__all__ = [ + "split_corpus", + "aeq", + "use_gpu", + "set_random_seed", + "ReportMgr", + "build_report_manager", + "Statistics", + "MultipleOptimizer", + "Optimizer", + "AdaFactorFairSeq", + "EarlyStopping", + "scorers_from_opts", + "make_batch_align_matrix", +] diff --git a/mammoth/utils/alignment.py b/onmt/utils/alignment.py similarity index 99% rename from mammoth/utils/alignment.py rename to onmt/utils/alignment.py index b7a1e6b4..f58761b9 100644 --- a/mammoth/utils/alignment.py +++ b/onmt/utils/alignment.py @@ -2,7 +2,7 @@ import torch from itertools import accumulate -from mammoth.constants import SubwordMarker +from onmt.constants import SubwordMarker def make_batch_align_matrix(index_tensor, size=None, normalize=False): diff --git a/onmt/utils/cnn_factory.py b/onmt/utils/cnn_factory.py new file mode 100644 index 00000000..68430426 --- /dev/null +++ b/onmt/utils/cnn_factory.py @@ -0,0 +1,52 @@ +""" +Implementation of "Convolutional Sequence to Sequence Learning" +""" +import torch +import torch.nn as nn +import torch.nn.init as init + +import onmt.modules + +SCALE_WEIGHT = 0.5**0.5 + + +def shape_transform(x): + """Tranform the size of the tensors to fit for conv input.""" + return torch.unsqueeze(torch.transpose(x, 1, 2), 3) + + +class GatedConv(nn.Module): + """Gated convolution for CNN class""" + + def __init__(self, input_size, width=3, dropout=0.2, nopad=False): + super(GatedConv, self).__init__() + self.conv = onmt.modules.WeightNormConv2d( + input_size, 2 * input_size, kernel_size=(width, 1), stride=(1, 1), padding=(width // 2 * (1 - nopad), 0) + ) + init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout)) ** 0.5) + self.dropout = nn.Dropout(dropout) + + def forward(self, x_var): + x_var = self.dropout(x_var) + x_var = self.conv(x_var) + out, gate = x_var.split(int(x_var.size(1) / 2), 1) + out = out * torch.sigmoid(gate) + return out + + +class StackedCNN(nn.Module): + """Stacked CNN class""" + + def __init__(self, num_layers, input_size, cnn_kernel_width=3, dropout=0.2): + super(StackedCNN, self).__init__() + self.dropout = dropout + self.num_layers = num_layers + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append(GatedConv(input_size, cnn_kernel_width, dropout)) + + def forward(self, x): + for conv in self.layers: + x = x + conv(x) + x *= SCALE_WEIGHT + return x diff --git a/mammoth/distributed/tasks.py b/onmt/utils/distributed.py similarity index 58% rename from mammoth/distributed/tasks.py rename to onmt/utils/distributed.py index d6023583..bb9f6263 100644 --- a/mammoth/distributed/tasks.py +++ b/onmt/utils/distributed.py @@ -1,31 +1,409 @@ -"""sub-module defining tasks, task specifications and task management objects.""" +""" Pytorch Distributed utils + This piece of code was heavily inspired by the equivalent of Fairseq-py + https://github.com/pytorch/fairseq +""" +import math +import numpy as np +import os +import pickle +import signal +import torch.distributed + from abc import ABC, abstractmethod from argparse import Namespace from collections import OrderedDict, namedtuple, Counter from dataclasses import dataclass +from enum import Enum from itertools import cycle, islice from pprint import pformat from typing import Any, Optional, List -import numpy as np -import torch -import torch.distributed +from onmt.utils.logging import init_logger, logger +from onmt.utils.misc import set_random_seed -from mammoth.distributed.contexts import DeviceContext, WorldContext -from mammoth.utils.logging import logger +class DeviceContextEnum(Enum): + CPU = 1 + SINGLE_GPU = 2 + MULTI_GPU = 3 -class TaskDistributionStrategy(ABC): + +@dataclass +class WorldContext: + context: DeviceContextEnum + # Size of the world: total number of nodes, gpus on each node + n_nodes: int + gpus_per_node: int + + @property + def world_size(self): + """Total number of training GPUs""" + return self.n_nodes * self.gpus_per_node + + def is_distributed(self): + """When training is distributed over several devices, + multiprocessing is used to communicate gradients""" + return self.context == DeviceContextEnum.MULTI_GPU + + def is_gpu(self): + """Data tensors must be moved to the GPU for compute""" + return self.context != DeviceContextEnum.CPU + + def is_master(self): + """For code that should only run in one process: + - saving fully shared modules from one device only + - avoiding log spam when all devices would log the same result + """ + return not self.is_distributed() or self.global_rank == 0 + + def global_to_local(self, node_rank, local_rank): + assert node_rank is not None + assert local_rank is not None + return DeviceContext( + context=self.context, + n_nodes=self.n_nodes, + gpus_per_node=self.gpus_per_node, + node_rank=node_rank, + local_rank=local_rank, + ) + + @classmethod + def from_opt(cls, opt): + gpus_per_node = len(opt.gpu_ranks) + world_size = int(opt.world_size) if gpus_per_node > 0 else 0 + multinode = gpus_per_node != world_size + if world_size <= 0: + # setting a non-positive world size means use CPU + device_context_enum = DeviceContextEnum.CPU + if opt.n_nodes != 1: + raise ValueError('CPU training is only possible on a single node') + elif world_size == 1: + # world size 1 uses GPU, but is not distributed + device_context_enum = DeviceContextEnum.SINGLE_GPU + if opt.n_nodes != 1: + raise ValueError( + f'Invalid single-gpu node configuration: ' + f'n_nodes {opt.n_nodes} gpus_per_node {gpus_per_node} world_size {world_size}' + ) + else: + # world size > 1 + if multinode and opt.n_nodes == 1: + raise ValueError( + f'Invalid multi-node configuration: ' + f'n_nodes {opt.n_nodes} gpus_per_node {gpus_per_node} world_size {world_size}' + ) + device_context_enum = DeviceContextEnum.MULTI_GPU + world_context = WorldContext(context=device_context_enum, n_nodes=opt.n_nodes, gpus_per_node=gpus_per_node) + return world_context + + +@dataclass +class DeviceContext(WorldContext): + # Our place in the world + node_rank: int + local_rank: int + + @property + def global_rank(self) -> int: + return self.gpus_per_node * self.node_rank + self.local_rank + + @property + def id(self) -> str: + if self.is_gpu(): + return f'GPU {self.node_rank}:{self.local_rank}' + else: + return 'CPU' + + def validate(self, world_context): + # check that this DeviceContext is consistent with given WorldContext + assert self.context == world_context.context + assert self.n_nodes == world_context.n_nodes + assert self.gpus_per_node == world_context.gpus_per_node + # check that ranks are within the specified size of the world + assert 0 <= self.node_rank < self.n_nodes + if self.is_gpu(): + assert 0 <= self.local_rank < self.gpus_per_node + + +def multi_init(opt, global_rank): + dist_init_method = 'tcp://{master_ip}:{master_port}'.format(master_ip=opt.master_ip, master_port=opt.master_port) + + dist_world_size = opt.world_size + torch.distributed.init_process_group( + backend=opt.gpu_backend, + init_method=dist_init_method, + rank=global_rank, + world_size=dist_world_size, + ) + + gpu_rank = torch.distributed.get_rank() + + return gpu_rank + + +def broadcast_tensors(tensors, src=0, group=None): + for t in tensors: + if group is None: + torch.distributed.broadcast(t, src) + else: + torch.distributed.broadcast(t, src, group=group) + + +def only_ready_reduce_and_rescale_grads(named_parameters, group=None): + """ + Gradient synch tolerant to missing grads. + + Missing grads occur when some parameters are not trained between two + gradient synchs, e.g. the embeddings of a low-resource language with low + sampling weight. + + The algorithm first uses the 'has_grad' attribute set by the forward hook + 'has_grad_hook'. This hook ensures that all parameters of the modules + selected for use during the current training computation have 'has_grad' + set to True. This gives the list of parameters that have been trained on + this device ("ready"). + + A bit mask covering the parameters that are ready on this device is + communicated to the other devices in the group. The bit masks are reduced + using summation. The sum gives the number of real gradients for that + parameter, and can be used for normalization. + + If a parameter is ready on any device, all devices communicate a value. + Devices on which the parameter is ready communicate the actual gradient, + while devices on which it is not ready communicate a dummy zero tensor + instead. The sum computed previously is used for normalization. + + Args: + named_parameters: tuples of (str, Parameter) defining the parameters to consider + group: torch.distributed communication group """ - An abstract task distribution strategy, controls which task will be scheduled next. + # Set missing gradients to zero, keeping track of true gradients + require_grad = [(name, p) for (name, p) in named_parameters if p.requires_grad] + if not require_grad: + # Exit early if the component has no parameters that require a gradient + return + device = require_grad[0][1].device + ready_list = [] + for name, p in require_grad: + if hasattr(p, 'has_grad') and p.has_grad: + ready_list.append(1.0) + else: + ready_list.append(0.0) + if p.grad is None: + p.grad = torch.zeros_like(p) + + # Communicate the ready bits, and reduce them using summation. + # This gives the number of non-dummy gradients participating, for normalization + ready_t = torch.tensor(ready_list).to(device) + if group is None: + torch.distributed.all_reduce(ready_t) + else: + torch.distributed.all_reduce(ready_t, group=group) + rescale_denoms = ready_t # after reduction + + # Omit if all nodes sent a zero ready bit + denoms_mask = (rescale_denoms > 0).cpu() + params_with_grad = [p for ((name, p), m) in zip(require_grad, denoms_mask) if m] + grads = [p.grad.data for p in params_with_grad] + rescale_denoms = [denom for (denom, m) in zip(rescale_denoms, denoms_mask) if m] + assert len(grads) == len(rescale_denoms) + if len(grads) == 0: + return + + # If not, then set has_grad also on devices that did not train the parameter themselves. + # They now have a grad that they received from the other devices. + for name, p in require_grad: + p.has_grad = True + + # All devices communicate either a real gradient or a dummy zeros of the same size + # Can not use rescale_denom, as each grad may have its own denominator + all_reduce_and_rescale_tensors(grads, rescale_denom=1, group=group) + + # Normalize using the previously computed values + for grad, denom in zip(grads, rescale_denoms): + if denom > 1: + grad.div_(denom) + # Note: p.has_grad is reused in the optimizer to prevent the untrained components from being stepped + + +def all_reduce_and_rescale_tensors(tensors, rescale_denom, group=None, buffer_size=10485760): """ + All-reduce and rescale tensors in chunks of the specified size. + + Args: + tensors: list of Tensors to all-reduce + rescale_denom: denominator for rescaling summed Tensors + buffer_size: all-reduce chunk size in bytes + """ + # buffer size in bytes, determine equiv. # of elements based on data type + buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_() + buffer = [] + + def all_reduce_buffer(): + # copy tensors into buffer_t + offset = 0 + for t in buffer: + numel = t.numel() + buffer_t[offset:offset + numel].copy_(t.view(-1)) + offset += numel + + # all-reduce and rescale + if group is None: + torch.distributed.all_reduce(buffer_t[:offset]) + else: + torch.distributed.all_reduce(buffer_t[:offset], group=group) + buffer_t.div_(rescale_denom) + + # copy all-reduced buffer back into tensors + offset = 0 + for t in buffer: + numel = t.numel() + t.view(-1).copy_(buffer_t[offset:offset + numel]) + offset += numel + + filled = 0 + for t in tensors: + sz = t.numel() * t.element_size() + if sz > buffer_size: + # tensor is bigger than buffer, all-reduce and rescale directly + if group is None: + torch.distributed.all_reduce(t) + else: + torch.distributed.all_reduce(t, group=group) + t.div_(rescale_denom) + elif filled + sz > buffer_size: + # buffer is full, all-reduce and replace buffer with grad + all_reduce_buffer() + buffer = [t] + filled = sz + else: + # add tensor to buffer + buffer.append(t) + filled += sz + + if len(buffer) > 0: + all_reduce_buffer() + + +def all_gather_list(data, max_size=4096): + """Gathers arbitrary data from all nodes into a list.""" + world_size = torch.distributed.get_world_size() + if not hasattr(all_gather_list, '_in_buffer') or max_size != all_gather_list._in_buffer.size(): + all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) + all_gather_list._out_buffers = [torch.cuda.ByteTensor(max_size) for i in range(world_size)] + in_buffer = all_gather_list._in_buffer + out_buffers = all_gather_list._out_buffers + + enc = pickle.dumps(data) + enc_size = len(enc) + if enc_size + 2 > max_size: + raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2)) + assert max_size < 255 * 256 + in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k + in_buffer[1] = enc_size % 255 + in_buffer[2:enc_size + 2] = torch.ByteTensor(list(enc)) + + torch.distributed.all_gather(out_buffers, in_buffer.cuda()) + + results = [] + for i in range(world_size): + out_buffer = out_buffers[i] + size = (255 * out_buffer[0].item()) + out_buffer[1].item() + + bytes_list = bytes(out_buffer[2:size + 2].tolist()) + result = pickle.loads(bytes_list) + results.append(result) + return results + + +class ErrorHandler(object): + """A class that listens for exceptions in children processes and propagates + the tracebacks to the parent process.""" + + def __init__(self, error_queue): + """init error handler""" + import signal + import threading + + self.error_queue = error_queue + self.children_pids = [] + self.error_thread = threading.Thread(target=self.error_listener, daemon=True) + self.error_thread.start() + signal.signal(signal.SIGUSR1, self.signal_handler) + + def add_child(self, pid): + """error handler""" + self.children_pids.append(pid) + + def error_listener(self): + """error listener""" + (rank, original_trace) = self.error_queue.get() + self.error_queue.put((rank, original_trace)) + os.kill(os.getpid(), signal.SIGUSR1) + + def signal_handler(self, signalnum, stackframe): + """signal handler""" + for pid in self.children_pids: + os.kill(pid, signal.SIGINT) # kill children processes + (rank, original_trace) = self.error_queue.get() + msg = """\n\n-- Tracebacks above this line can probably + be ignored --\n\n""" + msg += original_trace + raise Exception(msg) + + +def batch_producer(generator_to_serve, queue, semaphore, opt, device_id): + """Produce batches to `queues` from `generator_to_serve`.""" + log_level = "INFO" if opt.verbose or device_id == 0 else "WARNING" + init_logger(opt.log_file, log_level=log_level) + set_random_seed(opt.seed, False) + logger.info("BATCH PRODUCER") + logger.info(generator_to_serve) + + for batch, metadata, communication_batch_id in generator_to_serve: + semaphore.acquire() + # Move batch to correspond device_id when consumer iterate + # hack to dodge unpicklable `dict_keys` + # batch.fields = list(batch.fields) + queue.put((batch, metadata, communication_batch_id)) + + +def consumer(process_fn, opt, device_context, error_queue, batch_queue, semaphore, task_queue_manager): + """Run `process_fn` on `device_id` with data from `batch_queue`.""" + try: + logger.info( + f'global_rank {device_context.global_rank} ' + f'node_rank {device_context.node_rank} ' + f'local_rank {device_context.local_rank}' + ) + logger.info(f'opt.gpu_ranks {opt.gpu_ranks}') + multi_init(opt, device_context.global_rank) + # error_queue not passed (is this intentional?) + process_fn( + opt, + device_context=device_context, + batch_queue=batch_queue, + semaphore=semaphore, + task_queue_manager=task_queue_manager, + ) + + except KeyboardInterrupt: + pass # killed by parent, do nothing + except Exception: + # propagate exception to parent process, keeping original traceback + import traceback + + error_queue.put((opt.gpu_ranks[device_context.node_rank], traceback.format_exc())) + + +class TaskDistributionStrategy(ABC): @abstractmethod def __init__(self, my_corpus_ids: List[str], **kwargs): pass @classmethod @abstractmethod - def from_opts(cls, my_corpus_ids: List[str], opts: dict): + def from_opt(cls, my_corpus_ids: List[str], opt: dict): pass @abstractmethod @@ -64,10 +442,10 @@ def __init__( raise ValueError('Invalid curriculum: no corpus is ready to start in the first step') @classmethod - def from_opts(cls, my_corpus_ids: List[str], opts: dict): - my_weights = [opts.tasks[corpus_id]['weight'] for corpus_id in my_corpus_ids] + def from_opt(cls, my_corpus_ids: List[str], opt: dict): + my_weights = [opt.data[corpus_id]['weight'] for corpus_id in my_corpus_ids] my_introduce_at_training_step = [ - opts.tasks[corpus_id]['introduce_at_training_step'] for corpus_id in my_corpus_ids + opt.data[corpus_id]['introduce_at_training_step'] for corpus_id in my_corpus_ids ] return cls(my_corpus_ids, my_weights, my_introduce_at_training_step) @@ -101,7 +479,7 @@ def __init__(self, my_corpus_ids: List[str]): self.infinite_corpus_ids = cycle(my_corpus_ids) @classmethod - def from_opts(cls, my_corpus_ids: List[str], opts: dict): + def from_opt(cls, my_corpus_ids: List[str], opt: dict): return cls(my_corpus_ids) def sample_corpus_ids( @@ -133,7 +511,7 @@ class TaskSpecs(): decoder_id: List[str] corpus_id: str weight: int - corpus_opts: dict + corpus_opt: dict src_vocab: Any # FIXME: type tgt_vocab: Any encoder_adapter_ids: List[str] @@ -156,11 +534,11 @@ def get_serializable_metadata(self): ) -def get_adapter_ids(opts, corpus_opts, side): - if 'adapters' not in opts or 'adapters' not in corpus_opts: +def get_adapter_ids(opt, corpus_opt, side): + if 'adapters' not in opt or 'adapters' not in corpus_opt: return [] - global_adapters_opt = opts.adapters.get(side, None) - corpus_adapter_opt = corpus_opts['adapters'].get(side, None) + global_adapters_opt = opt.adapters.get(side, None) + corpus_adapter_opt = corpus_opt['adapters'].get(side, None) if not global_adapters_opt or not corpus_adapter_opt: return [] result = [] @@ -229,17 +607,17 @@ def local_rank(self): return self.device_context.local_rank @classmethod - def from_opts(cls, opts: Namespace, world_context: WorldContext): - n_tasks = len(opts.tasks) + def from_opt(cls, opt: Namespace, world_context: WorldContext): + n_tasks = len(opt.data) # Sorting the keys, to ensure that tasks have a consistent order across devices. # This in turn ensures the order in which components are created from those tasks. - corpus_ids = sorted(opts.tasks.keys()) + corpus_ids = sorted(opt.data.keys()) if world_context.is_distributed(): - if any(task.get('node_gpu', None) is not None for task in opts.tasks.values()): + if any(task.get('node_gpu', None) is not None for task in opt.data.values()): node_gpu = [ - tuple(int(y) for y in opts.tasks[corpus_id]['node_gpu'].split(':', 1)) + tuple(int(y) for y in opt.data[corpus_id]['node_gpu'].split(':', 1)) for corpus_id in corpus_ids] else: # When --node_gpu is not set, assume an assigment that fills gpus in rank order @@ -248,24 +626,24 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext): node_gpu = [(0, 0)] * n_tasks enc_sharing_group = [ - opts.tasks[corpus_id].get('enc_sharing_group', None) for corpus_id in corpus_ids + opt.data[corpus_id].get('enc_sharing_group', None) for corpus_id in corpus_ids ] dec_sharing_group = [ - opts.tasks[corpus_id].get('dec_sharing_group', None) for corpus_id in corpus_ids + opt.data[corpus_id].get('dec_sharing_group', None) for corpus_id in corpus_ids ] if any(x is not None for x in enc_sharing_group): - assert all(len(enc_ids) == len(opts.enc_layers) for enc_ids in enc_sharing_group) + assert all(len(enc_ids) == len(opt.enc_layers) for enc_ids in enc_sharing_group) else: # if no encoder sharing groups are defined, # it is assumed that there is only one encoder stack and it is language specific - if not len(opts.enc_layers) == 1: + if not len(opt.enc_layers) == 1: raise Exception('With more than one encoder stack, you must explictly define enc_sharing_group') if any(x is not None for x in dec_sharing_group): - assert all(len(dec_ids) == len(opts.dec_layers) for dec_ids in dec_sharing_group) + assert all(len(dec_ids) == len(opt.dec_layers) for dec_ids in dec_sharing_group) else: # if no decoder sharing groups are defined, # it is assumed that there is only one decoder stack and it is language specific - if not len(opts.dec_layers) == 1: + if not len(opt.dec_layers) == 1: raise Exception('With more than one decoder stack, you must explictly define dec_sharing_group') tasks = [] @@ -277,14 +655,14 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext): node_gpu, corpus_ids ): - corpus_opts = opts.tasks[corpus_id] - src_lang, tgt_lang = corpus_opts['src_tgt'].split('-', 1) - encoder_id = corpus_opts.get('enc_sharing_group', [src_lang]) - decoder_id = corpus_opts.get('dec_sharing_group', [tgt_lang]) - weight = corpus_opts.get('weight', 1.0) - if 'adapters' in corpus_opts: - encoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'encoder') - decoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'decoder') + corpus_opt = opt.data[corpus_id] + src_lang, tgt_lang = corpus_opt['src_tgt'].split('-', 1) + encoder_id = corpus_opt.get('enc_sharing_group', [src_lang]) + decoder_id = corpus_opt.get('dec_sharing_group', [tgt_lang]) + weight = corpus_opt.get('weight', 1.0) + if 'adapters' in corpus_opt: + encoder_adapter_ids = get_adapter_ids(opt, corpus_opt, 'encoder') + decoder_adapter_ids = get_adapter_ids(opt, corpus_opt, 'decoder') uses_adapters = True else: encoder_adapter_ids = None @@ -298,7 +676,7 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext): decoder_id=decoder_id, corpus_id=corpus_id, weight=weight, - corpus_opts=corpus_opts, + corpus_opt=corpus_opt, src_vocab=None, tgt_vocab=None, encoder_adapter_ids=encoder_adapter_ids, @@ -308,14 +686,14 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext): return cls( tasks, world_context=world_context, - tasks_per_communication_batch=opts.accum_count, + tasks_per_communication_batch=opt.accum_count, uses_adapters=uses_adapters, ) - def global_to_local(self, node_rank, local_rank, opts): + def global_to_local(self, node_rank, local_rank, opt): assert node_rank is not None assert local_rank is not None - task_distribution_strategy = self._get_strategy(node_rank=node_rank, local_rank=local_rank, opts=opts) + task_distribution_strategy = self._get_strategy(node_rank=node_rank, local_rank=local_rank, opt=opt) device_context = self.world_context.global_to_local(node_rank, local_rank) return self.__class__( self.tasks, @@ -328,15 +706,15 @@ def global_to_local(self, node_rank, local_rank, opts): uses_adapters=self.uses_adapters, ) - def _get_strategy(self, node_rank, local_rank, opts): + def _get_strategy(self, node_rank, local_rank, opt): assert node_rank is not None assert local_rank is not None # Global TQM does not have a task distribution strategy, but the local ones do my_corpus_ids = [task.corpus_id for task in self._tasks_on_device(node_rank, local_rank)] try: - strategy = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy].from_opts( + strategy = TASK_DISTRIBUTION_STRATEGIES[opt.task_distribution_strategy].from_opt( my_corpus_ids=my_corpus_ids, - opts=opts, + opt=opt, ) return strategy except Exception as e: @@ -558,11 +936,11 @@ def get_fields(self, side: str, fields_dict): raise RuntimeError # FIXME: merge with below - def get_vocabularies(self, opts: Namespace, side: str): + def get_vocabularies(self, opt: Namespace, side: str): result = [] for task in self.get_tasks(): lang = self.src_lang if side == 'src' else self.tgt_lang - vocab_path = opts.__getattribute__(f'{side}_vocab')[lang] + vocab_path = opt.__getattribute__(f'{side}_vocab')[lang] result.append((lang, vocab_path)) return result diff --git a/mammoth/utils/earlystopping.py b/onmt/utils/earlystopping.py similarity index 96% rename from mammoth/utils/earlystopping.py rename to onmt/utils/earlystopping.py index 6d20c60f..4244cf72 100644 --- a/mammoth/utils/earlystopping.py +++ b/onmt/utils/earlystopping.py @@ -1,5 +1,5 @@ from enum import Enum -from mammoth.utils.logging import logger +from onmt.utils.logging import logger class PatienceEnum(Enum): @@ -63,12 +63,12 @@ def _caller(self, stats): SCORER_BUILDER = {"ppl": PPLScorer, "accuracy": AccuracyScorer} -def scorers_from_opts(opts): - if opts.early_stopping_criteria is None: +def scorers_from_opts(opt): + if opt.early_stopping_criteria is None: return DEFAULT_SCORERS else: scorers = [] - for criterion in set(opts.early_stopping_criteria): + for criterion in set(opt.early_stopping_criteria): assert criterion in SCORER_BUILDER.keys(), "Criterion {} not found".format(criterion) scorers.append(SCORER_BUILDER[criterion]()) return scorers diff --git a/mammoth/utils/logging.py b/onmt/utils/logging.py similarity index 100% rename from mammoth/utils/logging.py rename to onmt/utils/logging.py diff --git a/mammoth/utils/loss.py b/onmt/utils/loss.py similarity index 86% rename from mammoth/utils/loss.py rename to onmt/utils/loss.py index 5061325d..100e7eb6 100644 --- a/mammoth/utils/loss.py +++ b/onmt/utils/loss.py @@ -6,11 +6,13 @@ import torch.nn as nn import torch.nn.functional as F -import mammoth -from mammoth.constants import ModelTask, DefaultTokens +import onmt +from onmt.modules.sparse_losses import SparsemaxLoss +from onmt.modules.sparse_activations import LogSparsemax +from onmt.constants import ModelTask, DefaultTokens -def build_loss_compute(model, tgt_vocab, opts, train=True, generator=None): +def build_loss_compute(model, tgt_vocab, opt, train=True, generator=None): """ Returns a LossCompute subclass which wraps around an nn.Module subclass (such as nn.NLLLoss) which defines the loss criterion. The LossCompute @@ -19,60 +21,62 @@ def build_loss_compute(model, tgt_vocab, opts, train=True, generator=None): Currently, the NMTLossCompute class handles all loss computation except for when using a copy mechanism. """ - device = torch.device("cuda" if mammoth.utils.misc.use_gpu(opts) else "cpu") + device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") padding_idx = tgt_vocab.stoi[DefaultTokens.PAD] unk_idx = tgt_vocab.stoi[DefaultTokens.UNK] - if opts.lambda_coverage != 0: - assert opts.coverage_attn, "--coverage_attn needs to be set in order to use --lambda_coverage != 0" + if opt.lambda_coverage != 0: + assert opt.coverage_attn, "--coverage_attn needs to be set in order to use --lambda_coverage != 0" - if opts.copy_attn: - criterion = mammoth.modules.CopyGeneratorLoss( - len(tgt_vocab), opts.copy_attn_force, unk_index=unk_idx, ignore_index=padding_idx + if opt.copy_attn: + criterion = onmt.modules.CopyGeneratorLoss( + len(tgt_vocab), opt.copy_attn_force, unk_index=unk_idx, ignore_index=padding_idx ) - elif opts.label_smoothing > 0 and train: - criterion = LabelSmoothingLoss(opts.label_smoothing, len(tgt_vocab), ignore_index=padding_idx) + elif opt.label_smoothing > 0 and train: + criterion = LabelSmoothingLoss(opt.label_smoothing, len(tgt_vocab), ignore_index=padding_idx) + elif isinstance(generator[-1], LogSparsemax): # elif isinstance(model.generator[-1], LogSparsemax): + criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum') else: criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum') # if the loss function operates on vectors of raw logits instead of # probabilities, only the first part of the generator needs to be - # passed to the NMTLossCompute. At the moment, there is no supported - # loss function of this kind. - use_raw_logits = False + # passed to the NMTLossCompute. At the moment, the only supported + # loss function of this kind is the sparsemax loss. + use_raw_logits = isinstance(criterion, SparsemaxLoss) loss_gen = ( generator[0] if use_raw_logits else generator ) # loss_gen = model.generator[0] if use_raw_logits else model.generator - if opts.copy_attn: - if opts.model_task == ModelTask.SEQ2SEQ: - compute = mammoth.modules.CopyGeneratorLossCompute( - criterion, loss_gen, tgt_vocab, opts.copy_loss_by_seqlength, lambda_coverage=opts.lambda_coverage + if opt.copy_attn: + if opt.model_task == ModelTask.SEQ2SEQ: + compute = onmt.modules.CopyGeneratorLossCompute( + criterion, loss_gen, tgt_vocab, opt.copy_loss_by_seqlength, lambda_coverage=opt.lambda_coverage ) - elif opts.model_task == ModelTask.LANGUAGE_MODEL: - compute = mammoth.modules.CopyGeneratorLMLossCompute( - criterion, loss_gen, tgt_vocab, opts.copy_loss_by_seqlength, lambda_coverage=opts.lambda_coverage + elif opt.model_task == ModelTask.LANGUAGE_MODEL: + compute = onmt.modules.CopyGeneratorLMLossCompute( + criterion, loss_gen, tgt_vocab, opt.copy_loss_by_seqlength, lambda_coverage=opt.lambda_coverage ) else: - raise ValueError(f"No copy generator loss defined for task {opts.model_task}") + raise ValueError(f"No copy generator loss defined for task {opt.model_task}") else: - if opts.model_task == ModelTask.SEQ2SEQ: + if opt.model_task == ModelTask.SEQ2SEQ: compute = NMTLossCompute( criterion, loss_gen, - lambda_coverage=opts.lambda_coverage, - lambda_align=opts.lambda_align, + lambda_coverage=opt.lambda_coverage, + lambda_align=opt.lambda_align, ) - elif opts.model_task == ModelTask.LANGUAGE_MODEL: - assert opts.lambda_align == 0.0, "lamdba_align not supported in LM loss" + elif opt.model_task == ModelTask.LANGUAGE_MODEL: + assert opt.lambda_align == 0.0, "lamdba_align not supported in LM loss" compute = LMLossCompute( criterion, loss_gen, - lambda_coverage=opts.lambda_coverage, - lambda_align=opts.lambda_align, + lambda_coverage=opt.lambda_coverage, + lambda_align=opt.lambda_align, ) else: - raise ValueError(f"No compute loss defined for task {opts.model_task}") + raise ValueError(f"No compute loss defined for task {opt.model_task}") compute.to(device) return compute @@ -159,7 +163,7 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_ trunc_size (int) : length of truncation window Returns: - A tuple with the loss and a :obj:`mammoth.utils.Statistics` instance. + A tuple with the loss and a :obj:`onmt.utils.Statistics` instance. """ if trunc_size is None: trunc_size = batch.tgt.size(0) - trunc_start @@ -168,7 +172,7 @@ def __call__(self, batch, output, attns, normalization=1.0, shard_size=0, trunc_ if shard_size == 0: loss, stats = self._compute_loss(batch, **shard_state) return loss / float(normalization), stats - batch_stats = mammoth.utils.Statistics() + batch_stats = onmt.utils.Statistics() for shard in shards(shard_state, shard_size): loss, stats = self._compute_loss(batch, **shard) loss.div(float(normalization)).backward() # retain_graph=True) @@ -183,13 +187,13 @@ def _stats(self, loss, scores, labels): labels (:obj:`FloatTensor`): true targets Returns: - :obj:`mammoth.utils.Statistics` : statistics for this batch. + :obj:`onmt.utils.Statistics` : statistics for this batch. """ pred = scores.max(1)[1] non_padding = labels.ne(self.padding_idx) num_correct = pred.eq(labels).masked_select(non_padding).sum().item() num_non_padding = non_padding.sum().item() - return mammoth.utils.Statistics(loss.item(), num_non_padding, num_correct) + return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct) def _bottle(self, _v): return _v.view(-1, _v.size(2)) @@ -300,7 +304,7 @@ def _add_align_shard_state(self, shard_state, batch, range_start, range_end, att pad_tgt_size, batch_size, _ = batch.tgt.size() pad_src_size = batch.src[0].size(0) align_matrix_size = [batch_size, pad_tgt_size, pad_src_size] - ref_align = mammoth.utils.make_batch_align_matrix(align_idx, align_matrix_size, normalize=True) + ref_align = onmt.utils.make_batch_align_matrix(align_idx, align_matrix_size, normalize=True) # NOTE: tgt-src ref alignement that in range_ of shard # (coherent with batch.tgt) shard_state.update( diff --git a/mammoth/utils/misc.py b/onmt/utils/misc.py similarity index 97% rename from mammoth/utils/misc.py rename to onmt/utils/misc.py index 280932f6..36cd5b82 100644 --- a/mammoth/utils/misc.py +++ b/onmt/utils/misc.py @@ -79,11 +79,11 @@ def tile(x, count, dim=0): return x -def use_gpu(opts): +def use_gpu(opt): """ Creates a boolean if gpu used """ - return (hasattr(opts, 'gpu_ranks') and len(opts.gpu_ranks) > 0) or (hasattr(opts, 'gpu') and opts.gpu > -1) + return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or (hasattr(opt, 'gpu') and opt.gpu > -1) def set_random_seed(seed, is_cuda): diff --git a/mammoth/utils/module_splitter.py b/onmt/utils/module_splitter.py similarity index 98% rename from mammoth/utils/module_splitter.py rename to onmt/utils/module_splitter.py index 6e190a3a..738037be 100644 --- a/mammoth/utils/module_splitter.py +++ b/onmt/utils/module_splitter.py @@ -62,7 +62,7 @@ def explode_model(full_ab_model): # stuff necessary to build bilingual models combining modules model_frame = { "vocab": full_ab_model["vocab"], - "opts": full_ab_model["opts"], + "opt": full_ab_model["opt"], "optim": full_ab_model["optim"], } diff --git a/mammoth/utils/optimizers.py b/onmt/utils/optimizers.py similarity index 90% rename from mammoth/utils/optimizers.py rename to onmt/utils/optimizers.py index 57e20ebf..c6043683 100644 --- a/mammoth/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -7,7 +7,7 @@ import types from collections import Counter from math import sqrt -from mammoth.utils.misc import fn_args +from onmt.utils.misc import fn_args from torch.nn.utils import clip_grad_norm_ @@ -60,7 +60,7 @@ def attention_bridge_optimizer(model, task_queue_manager, base_optimizer): return optimizer -def build_torch_optimizer(model, opts, task_queue_manager): +def build_torch_optimizer(model, opt, task_queue_manager): """Builds the PyTorch optimizer. We use the default parameters for Adam that are suggested by @@ -76,105 +76,116 @@ def build_torch_optimizer(model, opts, task_queue_manager): Args: model: The model to optimize. - opts. The dictionary of options. + opt. The dictionary of options. Returns: A ``torch.optim.Optimizer`` instance. """ params = [p for p in model.parameters() if p.requires_grad] - betas = [opts.adam_beta1, opts.adam_beta2] - if opts.optim == 'sgd': - optimizer = optim.SGD(params, lr=opts.learning_rate) - elif opts.optim == 'adagrad': - optimizer = optim.Adagrad( - params, - lr=opts.learning_rate, - initial_accumulator_value=opts.adagrad_accumulator_init, - ) - elif opts.optim == 'adadelta': - optimizer = optim.Adadelta(params, lr=opts.learning_rate) - elif opts.optim == 'adafactor': + betas = [opt.adam_beta1, opt.adam_beta2] + if opt.optim == 'sgd': + optimizer = optim.SGD(params, lr=opt.learning_rate) + elif opt.optim == 'adagrad': + optimizer = optim.Adagrad(params, lr=opt.learning_rate, initial_accumulator_value=opt.adagrad_accumulator_init) + elif opt.optim == 'adadelta': + optimizer = optim.Adadelta(params, lr=opt.learning_rate) + elif opt.optim == 'adafactor': optimizer = attention_bridge_optimizer( model, task_queue_manager, - lambda params: AdaFactorFairSeq(params, weight_decay=opts.weight_decay), + lambda params: AdaFactorFairSeq(params, weight_decay=opt.weight_decay), ) - elif opts.optim == 'adam': + elif opt.optim == 'adam': optimizer = attention_bridge_optimizer( model, task_queue_manager, lambda params: optim.Adam( - params, lr=opts.learning_rate, betas=betas, eps=1e-9, weight_decay=opts.weight_decay + params, lr=opt.learning_rate, betas=betas, eps=1e-9, weight_decay=opt.weight_decay ) ) - elif opts.optim == 'adamw': + elif opt.optim == 'adamw': optimizer = attention_bridge_optimizer( model, task_queue_manager, lambda params: optim.AdamW( - params, lr=opts.learning_rate, betas=betas, eps=1e-9, weight_decay=opts.weight_decay + params, lr=opt.learning_rate, betas=betas, eps=1e-9, weight_decay=opt.weight_decay ) ) - elif opts.optim == 'fusedadam': + elif opt.optim == 'sparseadam': + encs = [] + decs = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + # TODO: Find a better way to check for sparse gradients. + if 'decoder' in name: + # print(name) + decs.append(param) + else: + encs.append(param) + optimizer = MultipleOptimizer( + [optim.Adam(encs, lr=opt.learning_rate, betas=betas, eps=1e-9), AdaFactorFairSeq(decs, warmup_init=True)] + ) + elif opt.optim == 'fusedadam': # we use here a FusedAdam() copy of an old Apex repo - optimizer = FusedAdam(params, lr=opts.learning_rate, betas=betas) - if opts.model_dtype == 'fp16': + optimizer = FusedAdam(params, lr=opt.learning_rate, betas=betas) + if opt.model_dtype == 'fp16': import apex # In this case use the old FusedAdam with FP16_optimizer wrapper - static_loss_scale = opts.loss_scale - dynamic_loss_scale = opts.loss_scale == 0 + static_loss_scale = opt.loss_scale + dynamic_loss_scale = opt.loss_scale == 0 optimizer = apex.contrib.optimizers.FP16_Optimizer( optimizer, static_loss_scale=static_loss_scale, dynamic_loss_scale=dynamic_loss_scale ) else: - raise ValueError('Invalid optimizer type: ' + opts.optim) + raise ValueError('Invalid optimizer type: ' + opt.optim) return optimizer -def make_learning_rate_decay_fn(opts): +def make_learning_rate_decay_fn(opt): """Returns the learning decay function from options.""" - if opts.decay_method == 'noam': - return functools.partial(noam_decay, warmup_steps=opts.warmup_steps, model_dim=opts.model_dim) - elif opts.decay_method == 'noamwd': + if opt.decay_method == 'noam': + return functools.partial(noam_decay, warmup_steps=opt.warmup_steps, model_size=opt.rnn_size) + elif opt.decay_method == 'noamwd': return functools.partial( noamwd_decay, - warmup_steps=opts.warmup_steps, - model_dim=opts.model_dim, - rate=opts.learning_rate_decay, - decay_steps=opts.decay_steps, - start_step=opts.start_decay_steps, + warmup_steps=opt.warmup_steps, + model_size=opt.rnn_size, + rate=opt.learning_rate_decay, + decay_steps=opt.decay_steps, + start_step=opt.start_decay_steps, ) - elif opts.decay_method == 'rsqrt': - return functools.partial(rsqrt_decay, warmup_steps=opts.warmup_steps) - elif opts.decay_method == 'linear_warmup': + elif opt.decay_method == 'rsqrt': + return functools.partial(rsqrt_decay, warmup_steps=opt.warmup_steps) + elif opt.decay_method == 'linear_warmup': return functools.partial( linear_warmup_decay, - warmup_steps=opts.warmup_steps, - rate=opts.learning_rate, - train_steps=opts.train_steps, + warmup_steps=opt.warmup_steps, + rate=opt.learning_rate, + train_steps=opt.train_steps, ) - elif opts.start_decay_steps is not None: + elif opt.start_decay_steps is not None: return functools.partial( exponential_decay, - rate=opts.learning_rate_decay, - decay_steps=opts.decay_steps, - start_step=opts.start_decay_steps, + rate=opt.learning_rate_decay, + decay_steps=opt.decay_steps, + start_step=opt.start_decay_steps, ) -def noam_decay(step, warmup_steps, model_dim): +def noam_decay(step, warmup_steps, model_size): """Learning rate schedule described in https://arxiv.org/pdf/1706.03762.pdf. """ - return model_dim ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)) + return model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)) -def noamwd_decay(step, warmup_steps, model_dim, rate, decay_steps, start_step=0): +def noamwd_decay(step, warmup_steps, model_size, rate, decay_steps, start_step=0): """Learning rate schedule optimized for huge batches""" return ( - model_dim ** (-0.5) + model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)) * rate ** (max(step - start_step + decay_steps, 0) // decay_steps) ) @@ -201,7 +212,7 @@ def linear_warmup_decay(step, warmup_steps, rate, train_steps): class MultipleOptimizer(object): - """Implement multiple optimizers""" + """Implement multiple optimizers needed for sparse adam""" def __init__(self, op, multiOptims_Langs=None): self.optimizers = op @@ -279,24 +290,24 @@ def __init__(self, optimizer, learning_rate, learning_rate_decay_fn=None, max_gr self._scaler = None @classmethod - def from_opts(cls, model, opts, task_queue_manager, checkpoint=None): + def from_opt(cls, model, opt, task_queue_manager, checkpoint=None): """Builds the optimizer from options. Args: cls: The ``Optimizer`` class to instantiate. model: The model to optimize. - opts: The dict of user options. + opt: The dict of user options. checkpoint: An optional checkpoint to load states from. Returns: An ``Optimizer`` instance. """ - optim_opt = opts + optim_opt = opt optim_state_dict = None - if opts.train_from and checkpoint is not None: + if opt.train_from and checkpoint is not None: optim = checkpoint['optim'] - ckpt_opt = checkpoint['opts'] + ckpt_opt = checkpoint['opt'] ckpt_state_dict = {} if isinstance(optim, Optimizer): # Backward compatibility. ckpt_state_dict['training_step'] = optim._step + 1 @@ -305,19 +316,19 @@ def from_opts(cls, model, opts, task_queue_manager, checkpoint=None): else: ckpt_state_dict = optim - if opts.reset_optim == 'none': + if opt.reset_optim == 'none': # Load everything from the checkpoint. optim_opt = ckpt_opt optim_state_dict = ckpt_state_dict - elif opts.reset_optim == 'all': + elif opt.reset_optim == 'all': # Build everything from scratch. pass - elif opts.reset_optim == 'states': + elif opt.reset_optim == 'states': # Reset optimizer, keep options. optim_opt = ckpt_opt optim_state_dict = ckpt_state_dict del optim_state_dict['optimizer'] - elif opts.reset_optim == 'keep_states': + elif opt.reset_optim == 'keep_states': # Reset options, keep optimizer. optim_state_dict = ckpt_state_dict @@ -328,8 +339,8 @@ def from_opts(cls, model, opts, task_queue_manager, checkpoint=None): max_grad_norm=optim_opt.max_grad_norm, ) - if opts.model_dtype == "fp16": - if opts.optim == "fusedadam": + if opt.model_dtype == "fp16": + if opt.optim == "fusedadam": optimizer._fp16 = "legacy" else: optimizer._fp16 = "amp" @@ -723,7 +734,11 @@ def step(self, closure=None, grads=None, output_params=None, scale=1.0, grad_nor if grad is None: grad = p.grad.data if grad.is_sparse: - raise RuntimeError('sparse gradient not supported') + raise RuntimeError( + 'FusedAdam does not support sparse \ + gradients, please consider \ + SparseAdam instead' + ) state = self.state[p] diff --git a/mammoth/utils/parse.py b/onmt/utils/parse.py similarity index 61% rename from mammoth/utils/parse.py rename to onmt/utils/parse.py index 54056cf9..e979e338 100644 --- a/mammoth/utils/parse.py +++ b/onmt/utils/parse.py @@ -4,10 +4,10 @@ import torch import yaml -import mammoth.opts as opts -from mammoth.utils.logging import logger -from mammoth.constants import CorpusName, ModelTask -from mammoth.transforms import AVAILABLE_TRANSFORMS +import onmt.opts as opts +from onmt.utils.logging import logger +from onmt.constants import CorpusName, ModelTask +from onmt.transforms import AVAILABLE_TRANSFORMS RE_NODE_GPU = re.compile(r'\d+:\d+') RE_SRC_TGT = re.compile(r'[^-]+-[^-]+') @@ -23,21 +23,21 @@ def _validate_file(file_path, info): raise IOError(f"Please check path of your {info} file! {file_path}") @classmethod - def _validate_adapters(cls, opts): + def _validate_adapters(cls, opt): """Parse corpora specified in data field of YAML file.""" - if not opts.adapters: + if not opt.adapters: return - adapter_opts = yaml.safe_load(opts.adapters) + adapter_opts = yaml.safe_load(opt.adapters) # TODO: validate adapter opts - opts.adapters = adapter_opts + opt.adapters = adapter_opts @classmethod - def _validate_data(cls, opts): + def _validate_data(cls, opt): """Parse tasks/language-pairs/corpora specified in data field of YAML file.""" - default_transforms = opts.transforms + default_transforms = opt.transforms if len(default_transforms) != 0: logger.info(f"Default transforms: {default_transforms}.") - corpora = yaml.safe_load(opts.tasks) + corpora = yaml.safe_load(opt.data) logger.info("Parsing corpora") n_without_node_gpu = 0 for cname, corpus in corpora.items(): @@ -47,7 +47,7 @@ def _validate_data(cls, opts): if _transforms is None: logger.info(f"Missing transforms field for {cname} data, set to default: {default_transforms}.") corpus['transforms'] = default_transforms - opts.data_task = ModelTask.SEQ2SEQ + opt.data_task = ModelTask.SEQ2SEQ """ # Check path path_src = corpus.get('path_src', None) @@ -57,13 +57,13 @@ def _validate_data(cls, opts): 'tgt path is also required for non language' ' modeling tasks.') else: - opts.data_task = ModelTask.SEQ2SEQ + opt.data_task = ModelTask.SEQ2SEQ if path_tgt is None: logger.warning( "path_tgt is None, it should be set unless the task" " is language modeling" ) - opts.data_task = ModelTask.LANGUAGE_MODEL + opt.data_task = ModelTask.LANGUAGE_MODEL # tgt is src for LM task corpus["path_tgt"] = path_src corpora[cname] = corpus @@ -73,7 +73,7 @@ def _validate_data(cls, opts): """ path_align = corpus.get('path_align', None) if path_align is None: - if hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0: + if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0: raise ValueError(f'Corpus {cname} alignment file path are required when lambda_align > 0.0') corpus['path_align'] = None else: @@ -140,111 +140,111 @@ def _validate_data(cls, opts): assert n_without_node_gpu == 0 or n_without_node_gpu == len(corpora) logger.info(f"Parsed {len(corpora)} corpora from -data.") - opts.tasks = corpora + opt.data = corpora - src_vocab = yaml.safe_load(opts.src_vocab) + src_vocab = yaml.safe_load(opt.src_vocab) logger.info(f"Parsed {len(src_vocab)} vocabs from -src_vocab.") - opts.src_vocab = src_vocab + opt.src_vocab = src_vocab - tgt_vocab = yaml.safe_load(opts.tgt_vocab) + tgt_vocab = yaml.safe_load(opt.tgt_vocab) logger.info(f"Parsed {len(tgt_vocab)} vocabs from -tgt_vocab.") - opts.tgt_vocab = tgt_vocab + opt.tgt_vocab = tgt_vocab @classmethod - def _validate_transforms_opts(cls, opts): + def _validate_transforms_opts(cls, opt): """Check options used by transforms.""" for name, transform_cls in AVAILABLE_TRANSFORMS.items(): - if name in opts._all_transform: - transform_cls._validate_options(opts) + if name in opt._all_transform: + transform_cls._validate_options(opt) @classmethod - def _get_all_transform(cls, opts): + def _get_all_transform(cls, opt): """Should only called after `_validate_data`.""" - all_transforms = set(opts.transforms) - for cname, corpus in opts.tasks.items(): + all_transforms = set(opt.transforms) + for cname, corpus in opt.data.items(): _transforms = set(corpus['transforms']) if len(_transforms) != 0: all_transforms.update(_transforms) - if hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0: + if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0: if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}): raise ValueError('lambda_align is not compatible with on-the-fly tokenization.') if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'denoising'}): raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.') - opts._all_transform = all_transforms + opt._all_transform = all_transforms @classmethod - def _get_all_transform_translate(cls, opts): - opts._all_transform = opts.transforms + def _get_all_transform_translate(cls, opt): + opt._all_transform = opt.transforms @classmethod - def _validate_fields_opts(cls, opts, build_vocab_only=False): + def _validate_fields_opts(cls, opt, build_vocab_only=False): """Check options relate to vocab and fields.""" - for cname, corpus in opts.tasks.items(): + for cname, corpus in opt.data.items(): if cname != CorpusName.VALID and corpus["src_feats"] is not None: - assert opts.src_feats_vocab, "-src_feats_vocab is required if using source features." - if isinstance(opts.src_feats_vocab, str): - opts.src_feats_vocab = yaml.safe_load(opts.src_feats_vocab) + assert opt.src_feats_vocab, "-src_feats_vocab is required if using source features." + if isinstance(opt.src_feats_vocab, str): + opt.src_feats_vocab = yaml.safe_load(opt.src_feats_vocab) for feature in corpus["src_feats"].keys(): - assert feature in opts.src_feats_vocab, f"No vocab file set for feature {feature}" + assert feature in opt.src_feats_vocab, f"No vocab file set for feature {feature}" if build_vocab_only: - if not opts.share_vocab: - assert opts.tgt_vocab, "-tgt_vocab is required if not -share_vocab." + if not opt.share_vocab: + assert opt.tgt_vocab, "-tgt_vocab is required if not -share_vocab." return # validation when train: - for key, vocab in opts.src_vocab.items(): + for key, vocab in opt.src_vocab.items(): cls._validate_file(vocab, info=f'src vocab ({key})') - if not opts.share_vocab: - for key, vocab in opts.tgt_vocab.items(): + if not opt.share_vocab: + for key, vocab in opt.tgt_vocab.items(): cls._validate_file(vocab, info=f'tgt vocab ({key})') - # if opts.dump_fields or opts.dump_transforms: - if opts.dump_transforms: + # if opt.dump_fields or opt.dump_transforms: + if opt.dump_transforms: assert ( - opts.save_data + opt.save_data ), "-save_data should be set if set -dump_transforms." # Check embeddings stuff - if opts.both_embeddings is not None: + if opt.both_embeddings is not None: assert ( - opts.src_embeddings is None and opts.tgt_embeddings is None + opt.src_embeddings is None and opt.tgt_embeddings is None ), "You don't need -src_embeddings or -tgt_embeddings \ if -both_embeddings is set." - if any([opts.both_embeddings is not None, opts.src_embeddings is not None, opts.tgt_embeddings is not None]): - assert opts.embeddings_type is not None, "You need to specify an -embedding_type!" + if any([opt.both_embeddings is not None, opt.src_embeddings is not None, opt.tgt_embeddings is not None]): + assert opt.embeddings_type is not None, "You need to specify an -embedding_type!" assert ( - opts.save_data + opt.save_data ), "-save_data should be set if use pretrained embeddings." @classmethod - def _validate_language_model_compatibilities_opts(cls, opts): - if opts.model_task != ModelTask.LANGUAGE_MODEL: + def _validate_language_model_compatibilities_opts(cls, opt): + if opt.model_task != ModelTask.LANGUAGE_MODEL: return logger.info("encoder is not used for LM task") - assert opts.share_vocab and (opts.tgt_vocab is None), "vocab must be shared for LM task" + assert opt.share_vocab and (opt.tgt_vocab is None), "vocab must be shared for LM task" - assert opts.decoder_type == "transformer", "Only transformer decoder is supported for LM task" + assert opt.decoder_type == "transformer", "Only transformer decoder is supported for LM task" @classmethod - def validate_prepare_opts(cls, opts, build_vocab_only=False): + def validate_prepare_opts(cls, opt, build_vocab_only=False): """Validate all options relate to prepare (data/transform/vocab).""" - if opts.n_sample != 0: + if opt.n_sample != 0: assert ( - opts.save_data + opt.save_data ), "-save_data should be set if \ want save samples." - cls._validate_data(opts) - cls._get_all_transform(opts) - cls._validate_transforms_opts(opts) - cls._validate_fields_opts(opts, build_vocab_only=build_vocab_only) + cls._validate_data(opt) + cls._get_all_transform(opt) + cls._validate_transforms_opts(opt) + cls._validate_fields_opts(opt, build_vocab_only=build_vocab_only) @classmethod - def validate_model_opts(cls, opts): - cls._validate_language_model_compatibilities_opts(opts) + def validate_model_opts(cls, opt): + cls._validate_language_model_compatibilities_opts(opt) class ArgumentParser(cfargparse.ArgumentParser, DataOptsCheckerMixin): @@ -270,103 +270,108 @@ def defaults(cls, *args): return defaults @classmethod - def update_model_opts(cls, model_opts): - cls._validate_adapters(model_opts) - if model_opts.model_dim > 0: - model_opts.model_dim = model_opts.model_dim - model_opts.model_dim = model_opts.model_dim + def update_model_opts(cls, model_opt): + cls._validate_adapters(model_opt) + if model_opt.word_vec_size > 0: + model_opt.src_word_vec_size = model_opt.word_vec_size + model_opt.tgt_word_vec_size = model_opt.word_vec_size # Backward compatibility with "fix_word_vecs_*" opts - if hasattr(model_opts, 'fix_word_vecs_enc'): - model_opts.freeze_word_vecs_enc = model_opts.fix_word_vecs_enc - if hasattr(model_opts, 'fix_word_vecs_dec'): - model_opts.freeze_word_vecs_dec = model_opts.fix_word_vecs_dec + if hasattr(model_opt, 'fix_word_vecs_enc'): + model_opt.freeze_word_vecs_enc = model_opt.fix_word_vecs_enc + if hasattr(model_opt, 'fix_word_vecs_dec'): + model_opt.freeze_word_vecs_dec = model_opt.fix_word_vecs_dec - if model_opts.layers > 0: + if model_opt.layers > 0: raise Exception('--layers is deprecated') - model_opts.brnn = model_opts.encoder_type == "brnn" + if model_opt.rnn_size > 0: + model_opt.enc_rnn_size = model_opt.rnn_size + model_opt.dec_rnn_size = model_opt.rnn_size - if model_opts.copy_attn_type is None: - model_opts.copy_attn_type = model_opts.global_attention + model_opt.brnn = model_opt.encoder_type == "brnn" - if model_opts.alignment_layer is None: - model_opts.alignment_layer = -2 - model_opts.lambda_align = 0.0 - model_opts.full_context_alignment = False + if model_opt.copy_attn_type is None: + model_opt.copy_attn_type = model_opt.global_attention + + if model_opt.alignment_layer is None: + model_opt.alignment_layer = -2 + model_opt.lambda_align = 0.0 + model_opt.full_context_alignment = False @classmethod - def validate_model_opts(cls, model_opts): - assert model_opts.model_type in ["text"], "Unsupported model type %s" % model_opts.model_type + def validate_model_opts(cls, model_opt): + assert model_opt.model_type in ["text"], "Unsupported model type %s" % model_opt.model_type # encoder and decoder should be same sizes - # assert same_size, "The encoder and decoder rnns must be the same size for now" + same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size + assert same_size, "The encoder and decoder rnns must be the same size for now" - if model_opts.share_embeddings: - if model_opts.model_type != "text": + assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, "Using SRU requires -gpu_ranks set." + if model_opt.share_embeddings: + if model_opt.model_type != "text": raise AssertionError("--share_embeddings requires --model_type text.") - if model_opts.lambda_align > 0.0: - assert model_opts.decoder_type == 'transformer', "Only transformer is supported to joint learn alignment." + if model_opt.lambda_align > 0.0: + assert model_opt.decoder_type == 'transformer', "Only transformer is supported to joint learn alignment." assert ( - model_opts.alignment_layer < model_opts.dec_layers - and model_opts.alignment_layer >= -model_opts.dec_layers + model_opt.alignment_layer < model_opt.dec_layers and model_opt.alignment_layer >= -model_opt.dec_layers ), "N° alignment_layer should be smaller than number of layers." logger.info( "Joint learn alignment at layer [{}] " "with {} heads in full_context '{}'.".format( - model_opts.alignment_layer, model_opts.alignment_heads, model_opts.full_context_alignment + model_opt.alignment_layer, model_opt.alignment_heads, model_opt.full_context_alignment ) ) @classmethod def ckpt_model_opts(cls, ckpt_opt): - # Load default opts values, then overwrite with the opts in + # Load default opt values, then overwrite with the opts in # the checkpoint. That way, if there are new options added, # the defaults are used. - the_opts = cls.defaults(opts.model_opts) - the_opts.__dict__.update(ckpt_opt.__dict__) - return the_opts + opt = cls.defaults(opts.model_opts) + opt.__dict__.update(ckpt_opt.__dict__) + return opt @classmethod - def validate_train_opts(cls, opts): - if opts.epochs: + def validate_train_opts(cls, opt): + if opt.epochs: raise AssertionError("-epochs is deprecated please use -train_steps.") - if opts.truncated_decoder > 0 and max(opts.accum_count) > 1: + if opt.truncated_decoder > 0 and max(opt.accum_count) > 1: raise AssertionError("BPTT is not compatible with -accum > 1") - if opts.gpuid: + if opt.gpuid: raise AssertionError("gpuid is deprecated see world_size and gpu_ranks") - if torch.cuda.is_available() and not opts.gpu_ranks: + if torch.cuda.is_available() and not opt.gpu_ranks: logger.warn("You have a CUDA device, should run with -gpu_ranks") - if opts.world_size < len(opts.gpu_ranks): + if opt.world_size < len(opt.gpu_ranks): raise AssertionError("parameter counts of -gpu_ranks must be less or equal than -world_size.") - if len(opts.gpu_ranks) > 0 and opts.world_size == len(opts.gpu_ranks) and min(opts.gpu_ranks) > 0: + if len(opt.gpu_ranks) > 0 and opt.world_size == len(opt.gpu_ranks) and min(opt.gpu_ranks) > 0: raise AssertionError( "-gpu_ranks should have master(=0) rank unless -world_size is greater than len(gpu_ranks)." ) - assert len(opts.dropout) == len(opts.dropout_steps), "Number of dropout values must match accum_steps values" + assert len(opt.dropout) == len(opt.dropout_steps), "Number of dropout values must match accum_steps values" - assert len(opts.attention_dropout) == len( - opts.dropout_steps + assert len(opt.attention_dropout) == len( + opt.dropout_steps ), "Number of attention_dropout values must match accum_steps values" - assert len(opts.accum_count) == len( - opts.accum_steps + assert len(opt.accum_count) == len( + opt.accum_steps ), 'Number of accum_count values must match number of accum_steps' - if opts.update_vocab: - assert opts.train_from, "-update_vocab needs -train_from option" - assert opts.reset_optim in ['states', 'all'], '-update_vocab needs -reset_optim "states" or "all"' + if opt.update_vocab: + assert opt.train_from, "-update_vocab needs -train_from option" + assert opt.reset_optim in ['states', 'all'], '-update_vocab needs -reset_optim "states" or "all"' @classmethod - def validate_translate_opts(cls, opts): - opts.src_feats = eval(opts.src_feats) if opts.src_feats else {} + def validate_translate_opts(cls, opt): + opt.src_feats = eval(opt.src_feats) if opt.src_feats else {} @classmethod - def validate_translate_opts_dynamic(cls, opts): + def validate_translate_opts_dynamic(cls, opt): # It comes from training - # TODO: needs to be added as inference opts - opts.share_vocab = False + # TODO: needs to be added as inference opt + opt.share_vocab = False - opts.stack = yaml.safe_load(opts.stack) + opt.stack = yaml.safe_load(opt.stack) diff --git a/mammoth/utils/report_manager.py b/onmt/utils/report_manager.py similarity index 86% rename from mammoth/utils/report_manager.py rename to onmt/utils/report_manager.py index 822938d0..35554a8a 100644 --- a/mammoth/utils/report_manager.py +++ b/onmt/utils/report_manager.py @@ -2,28 +2,28 @@ import time from datetime import datetime -import mammoth +import onmt -from mammoth.utils.logging import logger +from onmt.utils.logging import logger -def build_report_manager(opts, node_rank, local_rank): - # Vanilla mammoth has here an additional gpu_rank <= 0 +def build_report_manager(opt, node_rank, local_rank): + # Vanilla onmt has here an additional gpu_rank <= 0 # which would cause only the first GPU of each node to log. # This change allows all GPUs to log. # Because tensorboard does not allow multiple processes writing into the same directory, # each device is treated as a separate run. - if opts.tensorboard: + if opt.tensorboard: from torch.utils.tensorboard import SummaryWriter - if not hasattr(opts, 'tensorboard_log_dir_dated'): - opts.tensorboard_log_dir_dated = opts.tensorboard_log_dir + datetime.now().strftime("/%b-%d_%H-%M-%S") + if not hasattr(opt, 'tensorboard_log_dir_dated'): + opt.tensorboard_log_dir_dated = opt.tensorboard_log_dir + datetime.now().strftime("/%b-%d_%H-%M-%S") - writer = SummaryWriter(f'{opts.tensorboard_log_dir_dated}-rank{node_rank}:{local_rank}', comment="Unmt") + writer = SummaryWriter(f'{opt.tensorboard_log_dir_dated}-rank{node_rank}:{local_rank}', comment="Unmt") else: writer = None - report_mgr = ReportMgr(opts.report_every, start_time=-1, tensorboard_writer=writer) + report_mgr = ReportMgr(opt.report_every, start_time=-1, tensorboard_writer=writer) return report_mgr @@ -73,9 +73,9 @@ def report_training(self, step, num_steps, learning_rate, patience, report_stats if step % self.report_every == 0: # if multigpu: # report_stats = \ - # mammoth.utils.Statistics.all_gather_stats(report_stats) + # onmt.utils.Statistics.all_gather_stats(report_stats) self._report_training(step, num_steps, learning_rate, patience, report_stats) - return mammoth.utils.Statistics() + return onmt.utils.Statistics() else: return report_stats @@ -125,7 +125,7 @@ def _report_training(self, step, num_steps, learning_rate, patience, report_stat report_stats.output(step, num_steps, learning_rate, self.start_time) self.maybe_log_tensorboard(report_stats, "progress", learning_rate, patience, step) - report_stats = mammoth.utils.Statistics() + report_stats = onmt.utils.Statistics() return report_stats diff --git a/onmt/utils/rnn_factory.py b/onmt/utils/rnn_factory.py new file mode 100644 index 00000000..f35c48e3 --- /dev/null +++ b/onmt/utils/rnn_factory.py @@ -0,0 +1,17 @@ +""" + RNN tools +""" +import torch.nn as nn +import onmt.models + + +def rnn_factory(rnn_type, **kwargs): + """rnn factory, Use pytorch version when available.""" + no_pack_padded_seq = False + if rnn_type == "SRU": + # SRU doesn't support PackedSequence. + no_pack_padded_seq = True + rnn = onmt.models.sru.SRU(**kwargs) + else: + rnn = getattr(nn, rnn_type)(**kwargs) + return rnn, no_pack_padded_seq diff --git a/mammoth/utils/statistics.py b/onmt/utils/statistics.py similarity index 98% rename from mammoth/utils/statistics.py rename to onmt/utils/statistics.py index ff8292a5..17f3e151 100644 --- a/mammoth/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -7,7 +7,7 @@ from collections import Counter from torch.linalg import norm -from mammoth.utils.logging import logger +from onmt.utils.logging import logger class Statistics(object): @@ -65,7 +65,7 @@ def all_gather_stats_list(stat_list, max_size=4096): our_stats(list([`Statistics`])): list of updated stats """ from torch.distributed import get_rank - from mammoth.distributed import all_gather_list + from onmt.utils.distributed import all_gather_list # Get a list of world_size lists with len(stat_list) Statistics objects all_stats = all_gather_list(stat_list, max_size=max_size) diff --git a/server.py b/server.py index 43c54b1c..2e078ba6 100644 --- a/server.py +++ b/server.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from mammoth.bin.server import main +from onmt.bin.server import main if __name__ == "__main__": diff --git a/setup.py b/setup.py index 1dc4e9b0..555141bb 100644 --- a/setup.py +++ b/setup.py @@ -7,11 +7,11 @@ long_description = f.read() setup( - name='mammoth', - description='Massively Multilingual Modular Open Translation @ Helsinki', + name='OpenNMT-py', + description='A python implementation of OpenNMT', long_description=long_description, long_description_content_type='text/markdown', - version='0.1', + version='2.2.0', packages=find_packages(), project_urls={ "Documentation": "http://opennmt.net/OpenNMT-py/", @@ -35,12 +35,12 @@ ], entry_points={ "console_scripts": [ - # "onmt_server=mammoth.bin.server:main", - "mammoth_train=mammoth.bin.train:main", - "mammoth_translate=mammoth.bin.translate:main", - # "onmt_release_model=mammoth.bin.release_model:main", - # "onmt_average_models=mammoth.bin.average_models:main", - # "onmt_build_vocab=mammoth.bin.build_vocab:main", + "onmt_server=onmt.bin.server:main", + "onmt_train=onmt.bin.train:main", + "onmt_translate=onmt.bin.translate:main", + "onmt_release_model=onmt.bin.release_model:main", + "onmt_average_models=onmt.bin.average_models:main", + "onmt_build_vocab=onmt.bin.build_vocab:main", ], }, ) diff --git a/test_communication/test.py b/test_communication/test.py index 932a1aed..4895fabd 100644 --- a/test_communication/test.py +++ b/test_communication/test.py @@ -7,9 +7,9 @@ import timeout_decorator -import mammoth -from mammoth.bin.train import train -from mammoth.utils.parse import ArgumentParser +import onmt +from onmt.bin.train import train +from onmt.utils.parse import ArgumentParser import torch.multiprocessing as mp @@ -20,7 +20,7 @@ class TestTraining(TestCase): @classmethod def setUpClass(cls) -> None: cls.parser = ArgumentParser(description="train.py") - mammoth.opts.train_opts(cls.parser) + onmt.opts.train_opts(cls.parser) # clear output folders for folder in ["models", "tensorboard"]: if os.path.exists(folder): @@ -35,13 +35,13 @@ def tearDown(self) -> None: child_process.kill() @staticmethod - def _get_model_components(opts) -> List[str]: + def _get_model_components(opt) -> List[str]: # N.B: These components are only valid for very vanilla language-specific xcoder with fully shared AB models - components_enc = [f"encoder_0_{src_lang}" for src_lang in ast.literal_eval(opts.src_vocab).keys()] - components_dec = [f"encoder_0_{tgt_lang}" for tgt_lang in ast.literal_eval(opts.tgt_vocab).keys()] - components_gen = [f"generator_{tgt_lang}" for tgt_lang in ast.literal_eval(opts.tgt_vocab).keys()] - components_src_emb = [f"src_embeddings_{src_lang}" for src_lang in ast.literal_eval(opts.src_vocab).keys()] - components_tgt_emb = [f"tgt_embeddings_{tgt_lang}" for tgt_lang in ast.literal_eval(opts.tgt_vocab).keys()] + components_enc = [f"encoder_0_{src_lang}" for src_lang in ast.literal_eval(opt.src_vocab).keys()] + components_dec = [f"encoder_0_{tgt_lang}" for tgt_lang in ast.literal_eval(opt.tgt_vocab).keys()] + components_gen = [f"generator_{tgt_lang}" for tgt_lang in ast.literal_eval(opt.tgt_vocab).keys()] + components_src_emb = [f"src_embeddings_{src_lang}" for src_lang in ast.literal_eval(opt.src_vocab).keys()] + components_tgt_emb = [f"tgt_embeddings_{tgt_lang}" for tgt_lang in ast.literal_eval(opt.tgt_vocab).keys()] return [ "frame", "attention_bridge", @@ -55,7 +55,7 @@ def _get_model_components(opts) -> List[str]: @timeout_decorator.timeout(60) def test_training_1gpu_4pairs(self): out_model_prefix = "wmt_1gpu_4pairs" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -72,14 +72,14 @@ def test_training_1gpu_4pairs(self): "0:0", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -91,7 +91,7 @@ def test_training_1gpu_4pairs(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_lin(self): out_model_prefix = "wmt_1gpu_4pairs_lin" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -114,14 +114,14 @@ def test_training_1gpu_4pairs_ab_lin(self): "0:0", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -133,7 +133,7 @@ def test_training_1gpu_4pairs_ab_lin(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_ff(self): out_model_prefix = "wmt_1gpu_4pairs_ff" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -154,14 +154,14 @@ def test_training_1gpu_4pairs_ab_ff(self): "0:0", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -173,7 +173,7 @@ def test_training_1gpu_4pairs_ab_ff(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_tf(self): out_model_prefix = "wmt_1gpu_4pairs_tf" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -194,14 +194,14 @@ def test_training_1gpu_4pairs_ab_tf(self): "0:0", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -213,7 +213,7 @@ def test_training_1gpu_4pairs_ab_tf(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_simple(self): out_model_prefix = "wmt_1gpu_4pairs_simple" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -234,14 +234,14 @@ def test_training_1gpu_4pairs_ab_simple(self): "0:0", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -253,7 +253,7 @@ def test_training_1gpu_4pairs_ab_simple(self): @timeout_decorator.timeout(60) def test_training_1gpu_4pairs_ab_perceiver(self): out_model_prefix = "wmt_1gpu_4pairs_perceiver" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -274,14 +274,14 @@ def test_training_1gpu_4pairs_ab_perceiver(self): "0:0", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -293,7 +293,7 @@ def test_training_1gpu_4pairs_ab_perceiver(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs(self): out_model_prefix = "wmt_2gpus_4pairs" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -311,14 +311,14 @@ def test_training_2gpus_4pairs(self): "0:1", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -330,7 +330,7 @@ def test_training_2gpus_4pairs(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_lin(self): out_model_prefix = "wmt_2gpus_4pairs_lin" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -354,14 +354,14 @@ def test_training_2gpus_4pairs_ab_lin(self): "0:1", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -373,7 +373,7 @@ def test_training_2gpus_4pairs_ab_lin(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_ff(self): out_model_prefix = "wmt_2gpus_4pairs_ff" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -395,14 +395,14 @@ def test_training_2gpus_4pairs_ab_ff(self): "0:1", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -414,7 +414,7 @@ def test_training_2gpus_4pairs_ab_ff(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_tf(self): out_model_prefix = "wmt_2gpus_4pairs_tf" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -436,14 +436,14 @@ def test_training_2gpus_4pairs_ab_tf(self): "0:1", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -455,7 +455,7 @@ def test_training_2gpus_4pairs_ab_tf(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_simple(self): out_model_prefix = "wmt_2gpus_4pairs_simple" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -479,14 +479,14 @@ def test_training_2gpus_4pairs_ab_simple(self): "0:1", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -498,7 +498,7 @@ def test_training_2gpus_4pairs_ab_simple(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_ab_perceiver(self): out_model_prefix = "wmt_2gpus_4pairs_perceiver" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -520,14 +520,14 @@ def test_training_2gpus_4pairs_ab_perceiver(self): "0:1", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -539,7 +539,7 @@ def test_training_2gpus_4pairs_ab_perceiver(self): @timeout_decorator.timeout(60) def test_training_2gpus_4pairs_crossed(self): out_model_prefix = "wmt_2gpus_4pairs_crossed" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -557,14 +557,14 @@ def test_training_2gpus_4pairs_crossed(self): "0:0", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -576,7 +576,7 @@ def test_training_2gpus_4pairs_crossed(self): @timeout_decorator.timeout(60) def test_training_4gpus_4pairs(self): out_model_prefix = "wmt_4gpus_4pairs" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -596,14 +596,14 @@ def test_training_4gpus_4pairs(self): "0:3", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -615,7 +615,7 @@ def test_training_4gpus_4pairs(self): @timeout_decorator.timeout(120) def test_training_3gpus_12pairs(self): out_model_prefix = "wmt_3gpus_12pairs" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_12pairs.yml", @@ -642,14 +642,14 @@ def test_training_3gpus_12pairs(self): "0:2", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -661,7 +661,7 @@ def test_training_3gpus_12pairs(self): @timeout_decorator.timeout(120) def test_training_3gpus_21pairs(self): out_model_prefix = "wmt_3gpus_21pairs" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_21pairs.yml", @@ -697,14 +697,14 @@ def test_training_3gpus_21pairs(self): "0:2", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -716,7 +716,7 @@ def test_training_3gpus_21pairs(self): @timeout_decorator.timeout(120) def test_training_4gpus_12pairs(self): out_model_prefix = "wmt_4gpus_12pairs" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_12pairs.yml", @@ -744,14 +744,14 @@ def test_training_4gpus_12pairs(self): "0:3", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -763,7 +763,7 @@ def test_training_4gpus_12pairs(self): @timeout_decorator.timeout(120) def test_training_4gpus_24pairs(self): out_model_prefix = "wmt_4gpus_24pairs" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_24pairs.yml", @@ -803,14 +803,14 @@ def test_training_4gpus_24pairs(self): "0:3", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -822,7 +822,7 @@ def test_training_4gpus_24pairs(self): @timeout_decorator.timeout(120) def test_training_1gpu_tensorboard(self): out_model_prefix = "wmt_1gpu_tb" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -843,14 +843,14 @@ def test_training_1gpu_tensorboard(self): "0:0", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -868,7 +868,7 @@ def test_training_1gpu_tensorboard(self): @timeout_decorator.timeout(120) def test_training_2gpus_tensorboard(self): out_model_prefix = "wmt_2gpus_tb" - opts, _ = self.parser.parse_known_args( + opt, _ = self.parser.parse_known_args( [ "-config", "config/wmt_4pairs.yml", @@ -890,14 +890,14 @@ def test_training_2gpus_tensorboard(self): "0:1", ] ) - components = self._get_model_components(opts) + components = self._get_model_components(opt) out_files = ["models/{}_step_4_{}.pt".format(out_model_prefix, cmp) for cmp in components] for out_file in out_files: if os.path.exists(out_file): logger.info("Removing file {}".format(out_file)) os.remove(out_file) logger.info("Launch training") - train(opts) + train(opt) for cmp in components: self.assertNotIn("{}_step_2_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) self.assertIn("{}_step_4_{}.pt".format(out_model_prefix, cmp), os.listdir("models")) @@ -919,14 +919,14 @@ def test_training_2gpus_tensorboard(self): # @classmethod # def setUpClass(cls) -> None: # cls.parser = ArgumentParser(description="translate.py") -# mammoth.opts.config_opts(cls.parser) -# mammoth.opts.translate_opts(cls.parser) -# mammoth.opts.build_bilingual_model(cls.parser) +# onmt.opts.config_opts(cls.parser) +# onmt.opts.translate_opts(cls.parser) +# onmt.opts.build_bilingual_model(cls.parser) # # def test_translate(self): # # TODO: train model instead of loading one the one used now, # # remove all absolute paths, add test data in the repo -# opts, _ = self.parser.parse_known_args( +# opt, _ = self.parser.parse_known_args( # [ # "-gpu", # "0", @@ -945,4 +945,4 @@ def test_training_2gpus_tensorboard(self): # "-use_attention_bridge", # ] # ) -# translate(opts) +# translate(opt) diff --git a/tools/attention_bank.py b/tools/attention_bank.py index 38bdb866..d244a9ff 100644 --- a/tools/attention_bank.py +++ b/tools/attention_bank.py @@ -7,12 +7,12 @@ import torch import tqdm -from mammoth.inputters.dataset import ParallelCorpus -from mammoth.inputters.dataloader import build_dataloader -from mammoth.model_builder import load_test_multitask_model -from mammoth.opts import build_bilingual_model, _add_dynamic_transform_opts -from mammoth.transforms import get_transforms_cls, make_transforms, TransformPipe -from mammoth.utils.parse import ArgumentParser +from onmt.inputters.dataset import ParallelCorpus +from onmt.inputters.dataloader import build_dataloader +from onmt.model_builder import load_test_multitask_model +from onmt.opts import build_bilingual_model, _add_dynamic_transform_opts +from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe +from onmt.utils.parse import ArgumentParser def get_opts(): @@ -78,7 +78,7 @@ def _extract(sentences_file, model, vocabs_dict, transforms, enc_id, batch_size= yield (memory_bank, src_lengths) -def extract(opts, vocabs_dict, model, model_opts, transforms): +def extract(opts, vocabs_dict, model, model_opt, transforms): """Compute representations drawn from the encoder and save them to file.""" sentence_reps = [] for src, src_length in _extract( @@ -95,7 +95,7 @@ def extract(opts, vocabs_dict, model, model_opts, transforms): torch.save(sentence_reps, opts.dump_file) -def estimate(opts, vocabs_dict, model, model_opts, transforms): +def estimate(opts, vocabs_dict, model, model_opt, transforms): """Estimate the matrix-variate distribution of representations drawn from the encoder.""" try: import sklearn.covariance @@ -134,7 +134,7 @@ def estimate(opts, vocabs_dict, model, model_opts, transforms): # return sampling_fn -def classify(opts, vocabs_dict, model, model_opts, transforms): +def classify(opts, vocabs_dict, model, model_opt, transforms): """Learn a simple SGD classifier using representations drawn from the encoder.""" try: import sklearn.linear_model @@ -224,7 +224,7 @@ def main(): # ArgumentParser.validate_translate_opts_dynamic(opts) opts.enc_id = opts.enc_id or opts.src_lang - vocabs_dict, model, model_opts = load_test_multitask_model(opts, opts.model) + vocabs_dict, model, model_opt = load_test_multitask_model(opts, opts.model) command_fn = { fn.__name__: fn for fn in [extract, estimate, classify] @@ -238,7 +238,7 @@ def main(): ] transform = TransformPipe.build_from(data_transform) - command_fn(opts, vocabs_dict, model.to(opts.device), model_opts, transform) + command_fn(opts, vocabs_dict, model.to(opts.device), model_opt, transform) if __name__ == '__main__': diff --git a/tools/average_models.py b/tools/average_models.py index ce714f92..9e053a8c 100755 --- a/tools/average_models.py +++ b/tools/average_models.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from mammoth.bin.average_models import main +from onmt.bin.average_models import main if __name__ == "__main__": diff --git a/tools/embeddings_to_torch.py b/tools/embeddings_to_torch.py index 00f2d981..3bdb1fb2 100755 --- a/tools/embeddings_to_torch.py +++ b/tools/embeddings_to_torch.py @@ -4,7 +4,7 @@ import six import argparse import torch -from mammoth.utils.logging import init_logger, logger +from onmt.utils.logging import init_logger, logger # FIXME haven't touched that file yet... @@ -79,48 +79,48 @@ def main(): parser.add_argument('-verbose', action="store_true", default=False) parser.add_argument('-skip_lines', type=int, default=0, help="Skip first lines of the embedding file") parser.add_argument('-type', choices=["GloVe", "word2vec"], default="GloVe") - opts = parser.parse_args() + opt = parser.parse_args() - enc_vocab, dec_vocab = get_vocabs(opts.dict_file) + enc_vocab, dec_vocab = get_vocabs(opt.dict_file) # Read in embeddings - skip_lines = 1 if opts.type == "word2vec" else opts.skip_lines - if opts.emb_file_both is not None: - if opts.emb_file_enc is not None: + skip_lines = 1 if opt.type == "word2vec" else opt.skip_lines + if opt.emb_file_both is not None: + if opt.emb_file_enc is not None: raise ValueError("If --emb_file_both is passed in, you should not" "set --emb_file_enc.") - if opts.emb_file_dec is not None: + if opt.emb_file_dec is not None: raise ValueError("If --emb_file_both is passed in, you should not" "set --emb_file_dec.") set_of_src_and_tgt_vocab = set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys()) - logger.info("Reading encoder and decoder embeddings from {}".format(opts.emb_file_both)) - src_vectors, total_vec_count = read_embeddings(opts.emb_file_both, skip_lines, set_of_src_and_tgt_vocab) + logger.info("Reading encoder and decoder embeddings from {}".format(opt.emb_file_both)) + src_vectors, total_vec_count = read_embeddings(opt.emb_file_both, skip_lines, set_of_src_and_tgt_vocab) tgt_vectors = src_vectors logger.info("\tFound {} total vectors in file".format(total_vec_count)) else: - if opts.emb_file_enc is None: + if opt.emb_file_enc is None: raise ValueError( "If --emb_file_enc not provided. Please specify " "the file with encoder embeddings, or pass in " "--emb_file_both" ) - if opts.emb_file_dec is None: + if opt.emb_file_dec is None: raise ValueError( "If --emb_file_dec not provided. Please specify " "the file with encoder embeddings, or pass in " "--emb_file_both" ) - logger.info("Reading encoder embeddings from {}".format(opts.emb_file_enc)) - src_vectors, total_vec_count = read_embeddings(opts.emb_file_enc, skip_lines, filter_set=enc_vocab.stoi) + logger.info("Reading encoder embeddings from {}".format(opt.emb_file_enc)) + src_vectors, total_vec_count = read_embeddings(opt.emb_file_enc, skip_lines, filter_set=enc_vocab.stoi) logger.info("\tFound {} total vectors in file.".format(total_vec_count)) - logger.info("Reading decoder embeddings from {}".format(opts.emb_file_dec)) - tgt_vectors, total_vec_count = read_embeddings(opts.emb_file_dec, skip_lines, filter_set=dec_vocab.stoi) + logger.info("Reading decoder embeddings from {}".format(opt.emb_file_dec)) + tgt_vectors, total_vec_count = read_embeddings(opt.emb_file_dec, skip_lines, filter_set=dec_vocab.stoi) logger.info("\tFound {} total vectors in file".format(total_vec_count)) logger.info("After filtering to vectors in vocab:") logger.info("\t* enc: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(enc_vocab, src_vectors)) logger.info("\t* dec: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(dec_vocab, tgt_vectors)) # Write to file - enc_output_file = opts.output_file + ".enc.pt" - dec_output_file = opts.output_file + ".dec.pt" + enc_output_file = opt.output_file + ".enc.pt" + dec_output_file = opt.output_file + ".dec.pt" logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s" % (enc_output_file, dec_output_file)) torch.save(convert_to_torch_tensor(src_vectors, enc_vocab), enc_output_file) torch.save(convert_to_torch_tensor(tgt_vectors, dec_vocab), dec_output_file) diff --git a/tools/extract_embeddings.py b/tools/extract_embeddings.py index 49ecbca9..41b2ded4 100644 --- a/tools/extract_embeddings.py +++ b/tools/extract_embeddings.py @@ -2,14 +2,14 @@ import torch -import mammoth -import mammoth.model_builder +import onmt +import onmt.model_builder -from mammoth.utils.parse import ArgumentParser -import mammoth.opts +from onmt.utils.parse import ArgumentParser +import onmt.opts -from mammoth.utils.misc import use_gpu -from mammoth.utils.logging import init_logger, logger +from onmt.utils.misc import use_gpu +from onmt.utils.logging import init_logger, logger parser = argparse.ArgumentParser(description='translate.py') @@ -29,31 +29,31 @@ def write_embeddings(filename, dict, embeddings): def main(): dummy_parser = argparse.ArgumentParser(description='train.py') - mammoth.opts.model_opts(dummy_parser) + onmt.opts.model_opts(dummy_parser) dummy_opt = dummy_parser.parse_known_args([])[0] - opts = parser.parse_args() - opts.cuda = opts.gpu > -1 - if opts.cuda: - torch.cuda.set_device(opts.gpu) + opt = parser.parse_args() + opt.cuda = opt.gpu > -1 + if opt.cuda: + torch.cuda.set_device(opt.gpu) # Add in default model arguments, possibly added since training. - checkpoint = torch.load(opts.model, map_location=lambda storage, loc: storage) - model_opts = checkpoint['opts'] + checkpoint = torch.load(opt.model, map_location=lambda storage, loc: storage) + model_opt = checkpoint['opt'] fields = checkpoint['vocab'] src_dict = fields['src'].base_field.vocab # assumes src is text tgt_dict = fields['tgt'].base_field.vocab - model_opts = checkpoint['opts'] + model_opt = checkpoint['opt'] for arg in dummy_opt.__dict__: - if arg not in model_opts: - model_opts.__dict__[arg] = dummy_opt.__dict__[arg] + if arg not in model_opt: + model_opt.__dict__[arg] = dummy_opt.__dict__[arg] # build_base_model expects updated and validated opts - ArgumentParser.update_model_opts(model_opts) - ArgumentParser.validate_model_opts(model_opts) + ArgumentParser.update_model_opts(model_opt) + ArgumentParser.validate_model_opts(model_opt) - model = mammoth.model_builder.build_base_model(model_opts, fields, use_gpu(opts), checkpoint) + model = onmt.model_builder.build_base_model(model_opt, fields, use_gpu(opt), checkpoint) encoder = model.encoder # no encoder for LM task decoder = model.decoder @@ -61,10 +61,10 @@ def main(): decoder_embeddings = decoder.embeddings.word_lut.weight.data.tolist() logger.info("Writing source embeddings") - write_embeddings(opts.output_dir + "/src_embeddings.txt", src_dict, encoder_embeddings) + write_embeddings(opt.output_dir + "/src_embeddings.txt", src_dict, encoder_embeddings) logger.info("Writing target embeddings") - write_embeddings(opts.output_dir + "/tgt_embeddings.txt", tgt_dict, decoder_embeddings) + write_embeddings(opt.output_dir + "/tgt_embeddings.txt", tgt_dict, decoder_embeddings) logger.info('... done.') logger.info('Converting model...') diff --git a/tools/extract_vocabulary.py b/tools/extract_vocabulary.py index 3c062e54..e003cc81 100644 --- a/tools/extract_vocabulary.py +++ b/tools/extract_vocabulary.py @@ -60,12 +60,12 @@ def main(): help="""Specifies 'src' or 'tgt' side for 'field' file_type.""", ) - opts = parser.parse_args() + opt = parser.parse_args() vocabulary = {} - if opts.file_type == 'text': + if opt.file_type == 'text': print("Reading input file...") - for batch in read_files_batch(opts.file): + for batch in read_files_batch(opt.file): for sentence in batch: for w in sentence: if w in vocabulary: @@ -74,19 +74,19 @@ def main(): vocabulary[w] = 1 print("Writing vocabulary file...") - with open(opts.out_file, "w") as f: + with open(opt.out_file, "w") as f: for w, count in sorted(vocabulary.items(), key=lambda x: x[1], reverse=True): f.write("{0}\n".format(w)) else: - if opts.side not in ['src', 'tgt']: + if opt.side not in ['src', 'tgt']: raise ValueError("If using -file_type='field', specifies 'src' or 'tgt' argument for -side.") import torch print("Reading input file...") - if not len(opts.file) == 1: + if not len(opt.file) == 1: raise ValueError("If using -file_type='field', only pass one argument for -file.") - vocabs = torch.load(opts.file[0]) - voc = dict(vocabs)[opts.side] + vocabs = torch.load(opt.file[0]) + voc = dict(vocabs)[opt.side] try: word_list = voc[0][1].base_field.vocab.itos @@ -94,7 +94,7 @@ def main(): word_list = voc[0][1].vocab.itos print("Writing vocabulary file...") - with open(opts.out_file, "wb") as f: + with open(opt.out_file, "wb") as f: for w in word_list: f.write(u"{0}\n".format(w).encode("utf-8")) diff --git a/tools/release_model.py b/tools/release_model.py index dd437517..b716b115 100644 --- a/tools/release_model.py +++ b/tools/release_model.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from mammoth.bin.release_model import main +from onmt.bin.release_model import main if __name__ == "__main__": diff --git a/tools/spm_to_vocab.py b/tools/spm_to_vocab.py index f2371727..ba7d734d 100644 --- a/tools/spm_to_vocab.py +++ b/tools/spm_to_vocab.py @@ -3,7 +3,7 @@ # counts) import sys import math -from mammoth.constants import DefaultTokens +from onmt.constants import DefaultTokens OMIT = (DefaultTokens.UNK, DefaultTokens.BOS, DefaultTokens.EOS) diff --git a/tools/test_rouge.py b/tools/test_rouge.py index 12ccc35d..436edab9 100644 --- a/tools/test_rouge.py +++ b/tools/test_rouge.py @@ -7,7 +7,7 @@ import sys import codecs -from mammoth.utils.logging import init_logger, logger +from onmt.utils.logging import init_logger, logger def eval_rouge(cand, ref): diff --git a/train.py b/train.py index 1648b083..1b03c9bc 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from mammoth.bin.train import main +from onmt.bin.train import main if __name__ == "__main__": diff --git a/translate.py b/translate.py index c27cbfac..5ca91336 100644 --- a/translate.py +++ b/translate.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from mammoth.bin.translate import main +from onmt.bin.translate import main if __name__ == "__main__":