diff --git a/examples/synthdata.template.yaml b/examples/synthdata.template.yaml new file mode 100644 index 00000000..f4a8d95e --- /dev/null +++ b/examples/synthdata.template.yaml @@ -0,0 +1,146 @@ +#################################### +# Meta-opts to control config_config +config_config: + # The synth data task key is given as both src_lang and tgt_lang + # We need to specify both, otherwise config-config would think cross-task data is available, even though it is not + src_path: "data/synthdata/train.{src_lang}-{tgt_lang}.src" + tgt_path: "data/synthdata/train.{src_lang}-{tgt_lang}.tgt" + valid_src_path: "data/synthdata/test.{src_lang}-{tgt_lang}.src" + valid_tgt_path: "data/synthdata/test.{src_lang}-{tgt_lang}.tgt" + # Only autoencoder tasks exist in this setup. We turn on the autoencoder, and validation for autoencoder tasks. + autoencoder: True + autoencoder_validation: True + # No distance matrix, because 1) we specify groups manually, and 2) also we don't use groupwise shared parameters + distance_matrix: null + n_groups: 3 + # No task weighting based on (temperature-adjusted) corpus size + use_weight: False + temperature: 0.5 + # Do not generate a translation config for zero-shot tasks + zero_shot: False + # Transforms for translation tasks. As only autoencoder tasks exist in this setup, leave this empty. + transforms: [] + # Transforms for autoencoder tasks. Because this toy task uses a small vocabulary, we don't apply sentencepiece. + ae_transforms: + - filtertoolong + # The encoder consists of one language-specific layer stack + enc_sharing_groups: + - LANGUAGE + # The decoder consists of one language-specific layer stack + dec_sharing_groups: + - LANGUAGE + # Defaults for the distributed training setup: number of nodes and how many GPUs each node has. + # Override these in the config_config command line arguments. + n_gpus_per_node: 1 + n_nodes: 1 + # If using the "prefix" transform, use_src_lang_token would add a source language token in addition to the target language token. + use_src_lang_token: False + # Manually specified sharing groups. + groups: + multi_query_associative_recall_kv6_q2: multi_query_associative_recall + multi_query_associative_recall_kv20_q4: multi_query_associative_recall + multi_query_associative_recall_kv12_q8: multi_query_associative_recall + copy_source: copy_source + distractor_separator_kv20_q4: copy_source + distractor_separator_kv12_q8: copy_source + reverse_source: copy_source + sort_source: copy_source + counting: counting + reverse_counting: counting + +# Paths to vocabulary files. 
Also specifies which languages to consider as source and target languages +src_vocab: + multi_query_associative_recall_kv6_q2: "data/synthdata/shared_vocab" + multi_query_associative_recall_kv20_q4: "data/synthdata/shared_vocab" + multi_query_associative_recall_kv12_q8: "data/synthdata/shared_vocab" + copy_source: "data/synthdata/shared_vocab" + distractor_separator_kv20_q4: "data/synthdata/shared_vocab" + distractor_separator_kv12_q8: "data/synthdata/shared_vocab" + reverse_source: "data/synthdata/shared_vocab" + sort_source: "data/synthdata/shared_vocab" + counting: "data/synthdata/shared_vocab" + reverse_counting: "data/synthdata/shared_vocab" +tgt_vocab: + multi_query_associative_recall_kv6_q2: "data/synthdata/shared_vocab" + multi_query_associative_recall_kv20_q4: "data/synthdata/shared_vocab" + multi_query_associative_recall_kv12_q8: "data/synthdata/shared_vocab" + copy_source: "data/synthdata/shared_vocab" + distractor_separator_kv20_q4: "data/synthdata/shared_vocab" + distractor_separator_kv12_q8: "data/synthdata/shared_vocab" + reverse_source: "data/synthdata/shared_vocab" + sort_source: "data/synthdata/shared_vocab" + counting: "data/synthdata/shared_vocab" + reverse_counting: "data/synthdata/shared_vocab" + +################################ +# Opts passed through to Mammoth + +# Prefix for model checkpoint files +save_model: models/synthdata + +# Maximum batch size for training, in tokens +batch_size: 8192 +batch_type: tokens +normalization: tokens +valid_batch_size: 4096 + +# Size of Transformer representations +model_dim: 256 +# The encoder consists of a single layerstack with 3 layers +enc_layers: [3] +# The decoder consists of a single layerstack with 2 layers +dec_layers: [2] +dropout: 0.1 +weight_decay: 0.05 +label_smoothing: 0.2 +# Stop training after this number of steps. Note that one step is accum_count minibatches. +train_steps: 50000 +# Perform validation every X steps +valid_steps: 1000 +# Warmup takes X steps to reach maximum learning rate +warmup_steps: 3000 +# Report training statistics every X steps +report_every: 1000 +# Save a checkpoint every X steps +save_checkpoint_steps: 10000 +# Delete oldest checkpoints, leaving this many +keep_checkpoint: 3 +# Set optimizer to SGD +optim: sgd +# Adam parameters (have no effect, since we use SGD) +adam_beta1: 0.9 +adam_beta2: 0.998 +# Ramp up learning rate linearly for warmup_steps, then decay it linearly until train_steps +decay_method: linear_warmup +# Maximum learning rate +learning_rate: 0.00003 +# Clip the norm of the gradient of each distributed component, if it exceeds this value. +# Don't rely on max_grad_norm to save you from a too-high learning rate: +# as each component is clipped individually, renormalization does NOT preserve the direction of the global gradient. +max_grad_norm: 1.0 +# Random seed for replicability +seed: 3435 +# Only text is supported for now +model_type: text +#### filtertoolong transform parameters +src_seq_length: 200 +tgt_seq_length: 200 +#### denoising transform parameters (not used in this configuration) +mask_length: span-poisson +poisson_lambda: 3.0 +mask_ratio: 0.2 +replace_length: 1 +denoising_objective: bart + +####################################### +# Opts passed through to x-transformers +x_transformers_opts: + # Use flash attention + attn_flash: True + # The number of attention heads + heads: 16 + # Use rotary positional embeddings. + # This seems to be the only type of positional embedding that works properly in Mammoth.
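(Note on the x_transformers_opts block above and the rotary_pos_emb / tie_embedding keys that follow it: these options are passed straight through to the x-transformers library. A rough standalone sketch of one such layer stack is shown below; num_tokens and max_seq_len are made-up illustrative values, and the keyword names assume the current x-transformers API.)

```python
from x_transformers import TransformerWrapper, Decoder

# Rough standalone equivalent of one decoder layer stack built from this template.
# num_tokens and max_seq_len are illustrative; dim, depth and heads mirror the
# model_dim, dec_layers and heads opts above.
decoder_stack = TransformerWrapper(
    num_tokens=1024,          # made-up vocabulary size
    max_seq_len=200,          # made-up maximum sequence length
    tie_embedding=True,       # tie the input and output token embeddings
    attn_layers=Decoder(
        dim=256,              # model_dim
        depth=2,              # dec_layers: [2]
        heads=16,             # heads
        rotary_pos_emb=True,  # rotary positional embeddings
        attn_flash=True,      # flash attention
    ),
)
```

dim and depth are deliberately absent from x_transformers_opts itself, because Mammoth fills them in per layer stack (see the overwritten-key check in validate_x_transformers_opts at the end of this diff).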
+ rotary_pos_emb: True + # Tie the input and output embeddings of the decoder + tie_embedding: True diff --git a/examples/synthdata.yaml b/examples/synthdata.yaml deleted file mode 100644 index 45f05f5a..00000000 --- a/examples/synthdata.yaml +++ /dev/null @@ -1,100 +0,0 @@ -config_config: - # The synth data task key is given as both src_lang and tgt_lang - # We need to specify both, otherwise config-config would think cross-task data is available, even though it is not - src_path: "data/synthdata/train.{src_lang}-{tgt_lang}.src" - tgt_path: "data/synthdata/train.{src_lang}-{tgt_lang}.tgt" - valid_src_path: "data/synthdata/test.{src_lang}-{tgt_lang}.src" - valid_tgt_path: "data/synthdata/test.{src_lang}-{tgt_lang}.tgt" - # only autoencoder tasks exist in this setup - autoencoder: True - distance_matrix: null - n_groups: 3 - use_weight: False - temperature: 0.5 - zero_shot: False - # only autoencoder tasks exist in this setup - transforms: [] - ae_transforms: - - filtertoolong - enc_sharing_groups: - - LANGUAGE - dec_sharing_groups: - - LANGUAGE - n_gpus_per_node: 1 - n_nodes: 1 - use_src_lang_token: False - groups: - multi_query_associative_recall_kv6_q2: multi_query_associative_recall - multi_query_associative_recall_kv20_q4: multi_query_associative_recall - multi_query_associative_recall_kv12_q8: multi_query_associative_recall - copy_source: copy_source - distractor_separator_kv20_q4: copy_source - distractor_separator_kv12_q8: copy_source - reverse_source: copy_source - sort_source: copy_source - counting: counting - reverse_counting: counting - - -src_vocab: - multi_query_associative_recall_kv6_q2: "data/synthdata/shared_vocab" - multi_query_associative_recall_kv20_q4: "data/synthdata/shared_vocab" - multi_query_associative_recall_kv12_q8: "data/synthdata/shared_vocab" - copy_source: "data/synthdata/shared_vocab" - distractor_separator_kv20_q4: "data/synthdata/shared_vocab" - distractor_separator_kv12_q8: "data/synthdata/shared_vocab" - reverse_source: "data/synthdata/shared_vocab" - sort_source: "data/synthdata/shared_vocab" - counting: "data/synthdata/shared_vocab" - reverse_counting: "data/synthdata/shared_vocab" -tgt_vocab: - multi_query_associative_recall_kv6_q2: "data/synthdata/shared_vocab" - multi_query_associative_recall_kv20_q4: "data/synthdata/shared_vocab" - multi_query_associative_recall_kv12_q8: "data/synthdata/shared_vocab" - copy_source: "data/synthdata/shared_vocab" - distractor_separator_kv20_q4: "data/synthdata/shared_vocab" - distractor_separator_kv12_q8: "data/synthdata/shared_vocab" - reverse_source: "data/synthdata/shared_vocab" - sort_source: "data/synthdata/shared_vocab" - counting: "data/synthdata/shared_vocab" - reverse_counting: "data/synthdata/shared_vocab" - -save_model: models/synthdata - -batch_size: 4096 -batch_type: tokens -normalization: tokens -valid_batch_size: 4096 -model_dim: 128 -ff_mult: 4 -heads: 8 -enc_layers: [2] -dec_layers: [2] -dropout: 0.1 -weight_decay: 0.05 -label_smoothing: 0.1 -param_init: 0.0 -param_init_glorot: true -train_steps: 150000 -valid_steps: 1000000 -warmup_steps: 10000 -report_every: 100 -save_checkpoint_steps: 25000 -keep_checkpoint: 10 -optim: adafactor -adam_beta1: 0.9 -adam_beta2: 0.998 -decay_method: rsqrt -learning_rate: 0.01 -max_grad_norm: 0.0 -seed: 3435 -model_type: text -#### Filter -src_seq_length: 200 -tgt_seq_length: 200 -#### Bart -mask_length: span-poisson -poisson_lambda: 3.0 -mask_ratio: 0.2 -replace_length: 1 -denoising_objective: bart diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py 
index d829ef62..d3760687 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -108,69 +108,37 @@ def get_transformer_wrapper_kwargs( return kwargs -def build_xcoder( +def build_adapters( side: Side, model_opts, - vocabs_dict: Dict[Tuple[str, str], Vocab], - device, task_queue_manager, single_task: Optional[str] = None, - token_embs: Optional[Dict[str, Vocab]] = None, -) -> StackXcoder: +) -> Optional[Dict[str, Adapter]]: """ - Build a StackXcoder for use as either Encoder or Decoder. - side: a Side enum from distributed components - model_opts: options - vocabs_dict: A dict mapping ('src'|'tgt', lang) to a Vocab. - device: torch.device - task_queue_manager: TaskQueueManager - single_task: if a task_id string is given, the built model contains only the components necessary for that task. - token_embs: to tie encoder and decoder embeddings, pass existing embeddings here. + Create AdapterLayer objects and Adapter objects """ - my_components: List[DistributedComponent] = task_queue_manager.get_my_distributed_components() - my_components = [ - component for component in my_components - if hasattr(component, 'side') and component.side == side - ] - distributed_xcoder_class: type + adapters_by_name: Optional[Dict[str, Adapter]] if side == Side.encoder: - distributed_xcoder_class = DistributedEncoderAttentionLayersBlock side_str = 'encoder' else: - distributed_xcoder_class = DistributedDecoderAttentionLayersBlock side_str = 'decoder' + my_components: List[DistributedComponent] = task_queue_manager.get_my_distributed_components() + my_side_specific_components = [ + component for component in my_components + if hasattr(component, 'side') and component.side == side + ] + if single_task: - my_components = [ - component for component in my_components + components_to_create = [ + component for component in my_side_specific_components if single_task in component.task_ids ] + else: + components_to_create = my_side_specific_components - # Create AdaptedAttentionLayers objects (an extension of an x_transformers.AttentionLayers block) - attention_layers_components = [ - component for component in my_components - if isinstance(component, distributed_xcoder_class) - ] - attention_layer_blocks = defaultdict(dict) - for component in attention_layers_components: - layer_stack_index = component.layer_stack_index - xcoder_id = component.xcoder_id - attention_layers_kwargs = get_attention_layers_kwargs( - side=side, - layer_stack_index=layer_stack_index, - xcoder_id=xcoder_id, - model_opts=model_opts, - ) - attention_layer_blocks[layer_stack_index][xcoder_id] = AdaptedAttentionLayers( - layer_stack_index=layer_stack_index, - xcoder_id=xcoder_id, - **attention_layers_kwargs - ) - - # Create AdapterLayer objects and Adapter objects - adapters_by_name: Optional[Dict[str, Adapter]] if uses_adapters(model_opts): adapter_components = [ - component for component in my_components + component for component in components_to_create if isinstance(component, DistributedAdapter) and component.side == side ] adapters_by_name = dict() @@ -203,25 +171,95 @@ def build_xcoder( ) else: raise ValueError(f'Unrecognized adapter_type {adapter_opts["adapter_type"]}') + layer_stack_index = adapter_params['layer_stack_index'] adapter = Adapter( adapter_group=component.adapter_group, sub_id=component.sub_id, + layer_stack_index=layer_stack_index, ) adapters_by_name[adapter.name] = adapter for layer_idx in adapter_params['layers']: adapter_layer = adapter_layer_func() adapter.add_layer(layer_idx, adapter_layer) - 
layer_stack_index = adapter_params['layer_stack_index'] - for xcoder_id, attention_layers in attention_layer_blocks[layer_stack_index].items(): + else: + adapters_by_name = None + return adapters_by_name + + +def build_xcoder( + side: Side, + model_opts, + vocabs_dict: Dict[Tuple[str, str], Vocab], + device, + task_queue_manager, + single_task: Optional[str] = None, + token_embs: Optional[Dict[str, Vocab]] = None, + adapters_by_name: Optional[Dict[str, Adapter]] = None, +) -> StackXcoder: + """ + Build a StackXcoder for use as either Encoder or Decoder. + side: a Side enum from distributed components + model_opts: options + vocabs_dict: A dict mapping ('src'|'tgt', lang) to a Vocab. + device: torch.device + task_queue_manager: TaskQueueManager + single_task: if a task_id string is given, the built model contains only the components necessary for that task. + token_embs: to tie encoder and decoder embeddings, pass existing embeddings here. + """ + my_components: List[DistributedComponent] = task_queue_manager.get_my_distributed_components() + my_side_specific_components = [ + component for component in my_components + if hasattr(component, 'side') and component.side == side + ] + + if single_task: + components_to_create = [ + component for component in my_side_specific_components + if single_task in component.task_ids + ] + else: + components_to_create = my_side_specific_components + + # Create AdaptedAttentionLayers objects (an extension of an x_transformers.AttentionLayers block) + distributed_xcoder_class: type + if side == Side.encoder: + distributed_xcoder_class = DistributedEncoderAttentionLayersBlock + elif side == Side.decoder: + distributed_xcoder_class = DistributedDecoderAttentionLayersBlock + else: + raise TypeError(type(side)) + attention_layers_components = [ + component for component in components_to_create + if isinstance(component, distributed_xcoder_class) + ] + + attention_layer_blocks: Dict[int, Dict[str, AdaptedAttentionLayers]] = defaultdict(dict) + for component in attention_layers_components: + layer_stack_index = component.layer_stack_index + xcoder_id = component.xcoder_id + attention_layers_kwargs = get_attention_layers_kwargs( + side=side, + layer_stack_index=layer_stack_index, + xcoder_id=xcoder_id, + model_opts=model_opts, + ) + attention_layer_blocks[layer_stack_index][xcoder_id] = AdaptedAttentionLayers( + layer_stack_index=layer_stack_index, + xcoder_id=xcoder_id, + **attention_layers_kwargs + ) + + # Add pre-created Adapters to the AdaptedAttentionLayers objects + if adapters_by_name is not None: + for adapter_name, adapter in adapters_by_name.items(): + for xcoder_id, attention_layers in attention_layer_blocks[adapter.layer_stack_index].items(): # TODO: allow limiting which xcoder_ids get the adapter? 
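Because the loop above registers the same pre-built Adapter instance into every AdaptedAttentionLayers block of its layer stack, the adapter's parameters are automatically shared across those blocks. This is ordinary PyTorch module sharing, illustrated here with toy stand-ins rather than the Mammoth classes:

```python
import torch.nn as nn

shared_adapter = nn.Linear(4, 4)     # stands in for a pre-built Adapter

class Block(nn.Module):              # stands in for one AdaptedAttentionLayers
    def __init__(self, adapter):
        super().__init__()
        self.adapter = adapter

block_a, block_b = Block(shared_adapter), Block(shared_adapter)

# Both blocks hold the very same parameter tensors: training through either
# block updates the adapter everywhere it was injected.
assert block_a.adapter.weight is block_b.adapter.weight

# Under a common parent, the shared parameters are counted only once.
parent = nn.ModuleList([block_a, block_b])
print(sum(p.numel() for p in parent.parameters()))  # 20, not 40
```

This is what lets build_adapters run once per side and hand the resulting Adapter objects to build_xcoder for injection into every matching block.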
- logger.info(f'adding {adapter.name} to {layer_stack_index}:{xcoder_id}:{component.sub_id}') + logger.info(f'adding {adapter.name} to {adapter.layer_stack_index}:{xcoder_id}:{adapter.sub_id}') try: attention_layers.add_adapter(adapter) except Exception as e: logger.error(repr(attention_layers)) raise e - else: - adapters_by_name = None # Create TokenEmbedding objects l2norm_embed = False @@ -329,6 +367,12 @@ def build_model( device = torch.device("cpu") logger.info(device) + enc_adapters_by_name: Optional[Dict[str, Adapter]] = build_adapters( + side=Side.encoder, + model_opts=model_opts, + task_queue_manager=task_queue_manager, + single_task=single_task, + ) encoder = build_xcoder( side=Side.encoder, model_opts=model_opts, @@ -336,6 +380,15 @@ def build_model( device=device, task_queue_manager=task_queue_manager, single_task=single_task, + adapters_by_name=enc_adapters_by_name, + ) + # TODO: to tie embeddings between encoder and decoder, + # take the token_embs from the encoder and pass them in the next build_xcoder call + dec_adapters_by_name: Optional[Dict[str, Adapter]] = build_adapters( + side=Side.decoder, + model_opts=model_opts, + task_queue_manager=task_queue_manager, + single_task=single_task, ) decoder = build_xcoder( side=Side.decoder, @@ -344,6 +397,7 @@ def build_model( device=device, task_queue_manager=task_queue_manager, single_task=single_task, + adapters_by_name=dec_adapters_by_name, ) attention_bridge = build_attention_bridge(model_opts) model = NMTModel( @@ -353,7 +407,12 @@ def build_model( ) model.to(device) - # logger.info(model) + if opts.log_model_structure: + logger.info(model) + for component in task_queue_manager.get_my_distributed_components(): + logger.info(component) + for name, p in model.named_parameters(): + logger.info(f'{p.requires_grad} {name}') logger.info('Building model - done!') return model diff --git a/mammoth/models/model.py b/mammoth/models/model.py index 912945fc..3c26a885 100644 --- a/mammoth/models/model.py +++ b/mammoth/models/model.py @@ -48,8 +48,8 @@ class NMTModel(BaseModel): Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder + decoder model. 
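The adapters.py hunk further below stores the wrapped base layer in a plain tuple precisely because assigning an nn.Module attribute registers it as a child, pulling it into the wrapper's parameters and state_dict. A small demonstration of that registration behaviour, with toy modules rather than the Mammoth classes:

```python
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.own = nn.Linear(2, 2)
        self.registered_base = base   # module attribute: registered as a child
        self._hidden_base = (base,)   # plain tuple: invisible to nn.Module bookkeeping

base = nn.Linear(2, 2)
wrapper = Wrapper(base)

print([name for name, _ in wrapper.named_children()])  # ['own', 'registered_base']
print(len(wrapper.state_dict()))                       # 4 tensors, none from the tuple
print(wrapper._hidden_base[0] is base)                 # True: still reachable in forward()
```

Keeping the base layer out of the child registry means the LoRA wrapper exposes only its A and B matrices as its own parameters, while the wrapped layer remains reachable for the forward computation.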
Args: - encoder (mammoth.encoders.EncoderBase): an encoder object - decoder (mammoth.decoders.DecoderBase): a decoder object + encoder (mammoth.modules.layer_stack.StackXcoder): an encoder object + decoder (mammoth.modules.layer_stack.StackXcoder): a decoder object """ def __init__(self, encoder, decoder, attention_bridge): diff --git a/mammoth/modules/adapters.py b/mammoth/modules/adapters.py index ae819dc7..9e8a067c 100644 --- a/mammoth/modules/adapters.py +++ b/mammoth/modules/adapters.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from collections import defaultdict -from typing import Union, Set, Dict +from typing import Union, Set, Dict, Tuple, Optional from functools import partial from x_transformers.x_transformers import ( @@ -69,7 +69,8 @@ def __init__( self.A = nn.Parameter(torch.randn(dim, r)) self.B = nn.Parameter(torch.zeros(r, dim_out)) - self.wrapped_base_layer = None + # the type is a hack to avoid registering the wrapped base layer as a child + self._wrapped_base_layer: Optional[Tuple[nn.Module]] = None @property def is_wrapper(self): @@ -82,7 +83,7 @@ def apply(self, tmp_layer_types, tmp_layer_structs, tmp_layer_dropouts): return tmp_layer_types, new_layer_structs, tmp_layer_dropouts def wrap(self, base_layer): - self.wrapped_base_layer = base_layer + self._wrapped_base_layer = (base_layer,) return self @property @@ -90,7 +91,11 @@ def weight(self): return (self.A @ self.B) * self.scale def forward(self, x): - return (x @ self.weight) + self.wrapped_base_layer.forward(x) + if self._wrapped_base_layer is None: + raise Exception('LoraAdapterLayer.wrap was not called before forward') + wrapped_base_layer = self._wrapped_base_layer[0] + self._wrapped_base_layer = None + return (x @ self.weight) + wrapped_base_layer.forward(x) AdapterLayer = Union[FeedForwardAdapterLayer, LoraAdapterLayer] @@ -102,10 +107,11 @@ class Adapter(nn.Module): together with layer indices for injecting into the base network. """ - def __init__(self, adapter_group: str, sub_id: str): + def __init__(self, adapter_group: str, sub_id: str, layer_stack_index: int): super().__init__() self.adapter_group = adapter_group self.sub_id = sub_id + self.layer_stack_index = layer_stack_index self.name = self._name(adapter_group, sub_id) # mapping layer_idx -> ModuleList of AdapterLayer to inject at that layer self.adapter_layers = nn.ModuleDict() diff --git a/mammoth/opts.py b/mammoth/opts.py index 9d754032..520c0602 100644 --- a/mammoth/opts.py +++ b/mammoth/opts.py @@ -45,6 +45,12 @@ def _add_logging_opts(parser, is_train=True): if is_train else 'Print scores and predictions for each sentence', ) + group.add( + '--log_model_structure', + '-log_model_structure', + action="store_true", + help='Print the entire model structure when building the model. Verbose, but useful for debugging.' + ) if is_train: group.add('--report_every', '-report_every', type=int, default=50, help="Print stats at this interval.") @@ -225,7 +231,7 @@ def model_opts(parser): '-model_dim', type=int, default=-1, - help="Size of rnn hidden states.", + help="Size of Transformer representations.", ) group.add( @@ -335,6 +341,22 @@ def model_opts(parser): help='Number of heads for transformer self-attention. ' ' Semi-obsolete: not used for x-transformers, only used for some attention bridge configuations.' 
) + group.add( + '--dropout', + '-dropout', + type=float, + default=[0.3], + nargs='+', + help="Dropout probability; Legacy: applied in the attention bridge", + ) + group.add( + '--attention_dropout', + '-attention_dropout', + type=float, + default=[0.1], + nargs='+', + help="Attention Dropout probability; Legacy: applied in the attention bridge", + ) # adapter options are in a dict "adapters", and in the corpus options group = parser.add_argument_group("Adapters") @@ -420,7 +442,7 @@ def _add_train_general_opts(parser): '-param_init', type=float, default=0.1, - help="Parameters are initialized over uniform distribution " + help="Legacy opt for attention bridge. Parameters are initialized over uniform distribution " "with support (-param_init, param_init). " "Use 0 to not use initialization", ) @@ -428,7 +450,7 @@ def _add_train_general_opts(parser): '--param_init_glorot', '-param_init_glorot', action='store_true', - help="Init parameters with xavier_uniform. Required for transformer.", + help="Legacy opt for attention bridge. Init parameters with xavier_uniform.", ) group.add( @@ -549,23 +571,6 @@ def _add_train_general_opts(parser): default=0.0, help="L2 penalty (weight decay) regularizer", ) - # FIXME, mentions LSTM - group.add( - '--dropout', - '-dropout', - type=float, - default=[0.3], - nargs='+', - help="Dropout probability; applied in LSTM stacks.", - ) - group.add( - '--attention_dropout', - '-attention_dropout', - type=float, - default=[0.1], - nargs='+', - help="Attention Dropout probability.", - ) group.add( '--dropout_steps', '-dropout_steps', type=int, nargs='+', default=[0], help="Steps at which dropout changes." ) @@ -905,7 +910,7 @@ def translate_opts(parser, dynamic=False): _add_logging_opts(parser, is_train=False) group = parser.add_argument_group('Efficiency') - group.add('--batch_size', '-batch_size', type=int, default=300, help='Batch size') + group.add('--batch_size', '-batch_size', type=int, default=200, help='Batch size') group.add( '--batch_type', '-batch_type', @@ -913,7 +918,7 @@ def translate_opts(parser, dynamic=False): choices=["sents", "tokens"], help="Batch grouping for batch_size. Standard is tokens (max of src and tgt). 
Sents is unimplemented.", ) - group.add('--gpu', '-gpu', type=int, default=-1, help="Device to run on") + group.add('--gpu_rank', '-gpu_rank', type=int, default=-1, help="Device to run on") group.add( "--output_model", diff --git a/mammoth/translate/translation_server.py b/mammoth/translate/translation_server.py index d96bc784..541ea262 100644 --- a/mammoth/translate/translation_server.py +++ b/mammoth/translate/translation_server.py @@ -119,7 +119,7 @@ def setdefault_if_exists_must_match(obj, name, value): onmt_for_translator = { "device": "cuda" if opts.cuda else "cpu", - "device_index": opts.gpu if opts.cuda else 0, + "device_index": opts.gpu_rank if opts.cuda else 0, } for name, value in onmt_for_translator.items(): setdefault_if_exists_must_match(ct2_translator_args, name, value) @@ -417,7 +417,7 @@ def parse_opt(self, opts): ArgumentParser.validate_prepare_opts(opts) ArgumentParser.validate_translate_opts(opts) ArgumentParser.validate_translate_opts_dynamic(opts) - opts.cuda = opts.gpu > -1 + opts.cuda = opts.gpu_rank > -1 sys.argv = prec_argv return opts @@ -727,7 +727,7 @@ def to_gpu(self): if isinstance(self.translator, CTranslate2Translator): self.translator.to_gpu() else: - torch.cuda.set_device(self.opts.gpu) + torch.cuda.set_device(self.opts.gpu_rank) self.translator.model.cuda() def maybe_preprocess(self, sequence): diff --git a/mammoth/translate/translator.py b/mammoth/translate/translator.py index 295ade6a..9ac54abc 100644 --- a/mammoth/translate/translator.py +++ b/mammoth/translate/translator.py @@ -28,7 +28,7 @@ def build_translator(opts, task_queue_manager, task, report_score=True, logger=N if out_file is None: outdir = os.path.dirname(opts.output) if outdir and not os.path.isdir(outdir): - warnings.warning(f'output file directory "{outdir}" does not exist... creating it.') + warnings.warn(f'output file directory "{outdir}" does not exist... 
creating it.') os.makedirs(os.path.dirname(opts.output), exist_ok=True) out_file = codecs.open(opts.output, "w+", "utf-8") @@ -308,7 +308,7 @@ def from_opts( vocabs, opts.src, tgt_file_path=opts.tgt, - gpu=opts.gpu, + gpu=opts.gpu_rank, n_best=opts.n_best, min_length=opts.min_length, max_length=opts.max_length, @@ -831,6 +831,8 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy): task_id=metadata.corpus_id, adapter_ids=metadata.decoder_adapter_ids, ) + active_encoder.to(self._device) + active_decoder.to(self._device) # (2) Run the encoder on the src encoder_output, src_mask = self._run_encoder(active_encoder, batch) diff --git a/mammoth/utils/misc.py b/mammoth/utils/misc.py index ac8caf27..edc23447 100644 --- a/mammoth/utils/misc.py +++ b/mammoth/utils/misc.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- -import torch -import random +import gzip import inspect import numpy as np -from itertools import islice, repeat -from io import StringIO import os +import random +import torch +from io import StringIO +from itertools import islice, repeat def check_path(path, exist_ok=False, log=print): @@ -33,7 +34,11 @@ def split_corpus(path, shard_size, default=None): def _split_corpus(path, shard_size): """Yield io's with `shard_size` lines each.""" # FIXME: this is a horrible, ugly kludge - with open(path, "rt") as f: + if path.endswith('.gz'): + open_func = gzip.open + else: + open_func = open + with open_func(path, "rt") as f: if shard_size <= 0: yield f else: diff --git a/mammoth/utils/model_saver.py b/mammoth/utils/model_saver.py index d75cc6c4..cf8eb301 100644 --- a/mammoth/utils/model_saver.py +++ b/mammoth/utils/model_saver.py @@ -284,7 +284,7 @@ def _save(self, step, model, data_state, task_queue_manager): module_state_dicts['frame']['data_state'] = data_state for key, state_dict in module_state_dicts.items(): - # The state_dicts across different devices only contain one copy of each module: + # The exploded state_dicts across different devices only contain one copy of each module: # on the lowest ranked device having that module. # There is no race condition. 
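The split_corpus change in mammoth/utils/misc.py above selects gzip.open or open based on the file extension. The same pattern as a tiny standalone helper (a generic sketch with a hypothetical name, not the Mammoth function; the path in the usage comment is likewise illustrative):

```python
import gzip


def open_maybe_gzip(path, mode="rt", encoding="utf-8"):
    """Open a text file the same way whether or not it is gzip-compressed."""
    opener = gzip.open if path.endswith(".gz") else open
    return opener(path, mode, encoding=encoding)


# Illustrative usage:
# with open_maybe_gzip("data/synthdata/train.copy_source-copy_source.src.gz") as f:
#     for line in f:
#         ...
```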
checkpoint_path = f'{self.base_path}_step_{step}_{key}.pt' diff --git a/mammoth/utils/parse.py b/mammoth/utils/parse.py index eacb44f2..38f4690b 100644 --- a/mammoth/utils/parse.py +++ b/mammoth/utils/parse.py @@ -13,6 +13,15 @@ RE_SRC_TGT = re.compile(r'[^-]+-[^-]+') +def yaml_or_dict(val, name): + if isinstance(val, str): + return yaml.safe_load(val) + elif isinstance(val, dict): + return val + else: + raise TypeError(f'{name} {type(val)}') + + class DataOptsCheckerMixin(object): """Checker with methods for validate data related options.""" @@ -27,7 +36,7 @@ def _validate_adapters(cls, opts): """Parse corpora specified in data field of YAML file.""" if not opts.adapters: return - adapter_opts = yaml.safe_load(opts.adapters) + adapter_opts = yaml_or_dict(opts.adapters, name='opts.adapters') # TODO: validate adapter opts opts.adapters = adapter_opts @@ -37,7 +46,7 @@ def _validate_tasks(cls, opts): default_transforms = opts.transforms if len(default_transforms) != 0: logger.info(f"Default transforms: {default_transforms}.") - corpora = yaml.safe_load(opts.tasks) + corpora = yaml_or_dict(opts.tasks, name='opts.tasks') logger.info("Parsing corpora") n_without_node_gpu = 0 for cname, corpus in corpora.items(): @@ -139,11 +148,11 @@ def _validate_tasks(cls, opts): logger.info(f"Parsed {len(corpora)} corpora from -data.") opts.tasks = corpora - src_vocab = yaml.safe_load(opts.src_vocab) + src_vocab = yaml_or_dict(opts.src_vocab, name="opts.src_vocab") logger.info(f"Parsed {len(src_vocab)} vocabs from -src_vocab.") opts.src_vocab = src_vocab - tgt_vocab = yaml.safe_load(opts.tgt_vocab) + tgt_vocab = yaml_or_dict(opts.tgt_vocab, name="opts.tgt_vocab") logger.info(f"Parsed {len(tgt_vocab)} vocabs from -tgt_vocab.") opts.tgt_vocab = tgt_vocab @@ -176,7 +185,7 @@ def _validate_fields_opts(cls, opts): if cname != CorpusName.VALID and corpus["src_feats"] is not None: assert opts.src_feats_vocab, "-src_feats_vocab is required if using source features." if isinstance(opts.src_feats_vocab, str): - opts.src_feats_vocab = yaml.safe_load(opts.src_feats_vocab) + opts.src_feats_vocab = yaml_or_dict(opts.src_feats_vocab, name="opts.src_feats_vocab") for feature in corpus["src_feats"].keys(): assert feature in opts.src_feats_vocab, f"No vocab file set for feature {feature}" @@ -258,7 +267,7 @@ def validate_x_transformers_opts(cls, opts): if not opts.x_transformers_opts: opts.x_transformers_opts = dict() return - opts_dict = yaml.safe_load(opts.x_transformers_opts) + opts_dict = yaml_or_dict(opts.x_transformers_opts, name="opts.x_transformers_opts") for overwritten_key in ( 'dim', 'depth',
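The new yaml_or_dict helper above lets the validators accept either a YAML string (as passed on the command line) or an already-parsed dict (as produced when the config file is loaded). A self-contained illustration, with the helper copied from the hunk above and made-up example values:

```python
import yaml


def yaml_or_dict(val, name):
    if isinstance(val, str):
        return yaml.safe_load(val)
    elif isinstance(val, dict):
        return val
    else:
        raise TypeError(f'{name} {type(val)}')


# A YAML string coming from the command line...
print(yaml_or_dict('{attn_flash: true, heads: 16}', name='opts.x_transformers_opts'))
# ...and an already-parsed dict both yield the same result.
print(yaml_or_dict({'attn_flash': True, 'heads': 16}, name='opts.x_transformers_opts'))
```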