From 56ac0970bd11fa823a1d720991540f627a38f738 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 9 Dec 2024 12:39:49 +0200 Subject: [PATCH] Code review fixes --- mammoth/model_builder.py | 42 ++++++++++++++++++++++++++-------------- mammoth/opts.py | 33 +++++++++++++++---------------- mammoth/utils/parse.py | 10 +++++----- 3 files changed, 48 insertions(+), 37 deletions(-) diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py index 101f7d36..190c136f 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -101,25 +101,31 @@ def build_adapters( task_queue_manager, single_task: Optional[str] = None, ) -> Optional[Dict[str, Adapter]]: - # Create AdapterLayer objects and Adapter objects + """ + Create AdapterLayer objects and Adapter objects + """ adapters_by_name: Optional[Dict[str, Adapter]] if side == Side.encoder: side_str = 'encoder' else: side_str = 'decoder' my_components: List[DistributedComponent] = task_queue_manager.get_my_distributed_components() - my_components = [ + my_side_specific_components = [ component for component in my_components if hasattr(component, 'side') and component.side == side ] + if single_task: - my_components = [ - component for component in my_components + components_to_create = [ + component for component in my_side_specific_components if single_task in component.task_ids ] + else: + components_to_create = my_side_specific_components + if uses_adapters(model_opts): adapter_components = [ - component for component in my_components + component for component in components_to_create if isinstance(component, DistributedAdapter) and component.side == side ] adapters_by_name = dict() @@ -188,26 +194,32 @@ def build_xcoder( token_embs: to tie encoder and decoder embeddings, pass existing embeddings here. """ my_components: List[DistributedComponent] = task_queue_manager.get_my_distributed_components() - my_components = [ + my_side_specific_components = [ component for component in my_components if hasattr(component, 'side') and component.side == side ] - distributed_xcoder_class: type - if side == Side.encoder: - distributed_xcoder_class = DistributedEncoderAttentionLayersBlock - else: - distributed_xcoder_class = DistributedDecoderAttentionLayersBlock + if single_task: - my_components = [ - component for component in my_components + components_to_create = [ + component for component in my_side_specific_components if single_task in component.task_ids ] + else: + components_to_create = my_side_specific_components # Create AdaptedAttentionLayers objects (an extension of an x_transformers.AttentionLayers block) + distributed_xcoder_class: type + if side == Side.encoder: + distributed_xcoder_class = DistributedEncoderAttentionLayersBlock + elif side == Side.decoder: + distributed_xcoder_class = DistributedDecoderAttentionLayersBlock + else: + raise TypeError(type(side)) attention_layers_components = [ - component for component in my_components + component for component in components_to_create if isinstance(component, distributed_xcoder_class) ] + attention_layer_blocks: Dict[int, Dict[str, AdaptedAttentionLayers]] = defaultdict(dict) for component in attention_layers_components: layer_stack_index = component.layer_stack_index @@ -385,7 +397,7 @@ def build_model( for component in task_queue_manager.get_my_distributed_components(): logger.info(component) for name, p in model.named_parameters(): - print(f'{p.requires_grad} {name}') + logger.info(f'{p.requires_grad} {name}') logger.info('Building model - done!') return model diff --git a/mammoth/opts.py b/mammoth/opts.py index 1ff4524b..04c9b3de 100644 --- a/mammoth/opts.py +++ b/mammoth/opts.py @@ -333,6 +333,22 @@ def model_opts(parser): help='Number of heads for transformer self-attention. ' ' Semi-obsolete: not used for x-transformers, only used for some attention bridge configuations.' ) + group.add( + '--dropout', + '-dropout', + type=float, + default=[0.3], + nargs='+', + help="Dropout probability; Legacy: applied in the attention bridge", + ) + group.add( + '--attention_dropout', + '-attention_dropout', + type=float, + default=[0.1], + nargs='+', + help="Attention Dropout probability; Legacy: applied in the attention bridge", + ) # adapter options are in a dict "adapters", and in the corpus options group = parser.add_argument_group("Adapters") @@ -547,23 +563,6 @@ def _add_train_general_opts(parser): default=0.0, help="L2 penalty (weight decay) regularizer", ) - # FIXME, mentions LSTM - group.add( - '--dropout', - '-dropout', - type=float, - default=[0.3], - nargs='+', - help="Dropout probability; applied in LSTM stacks. (Probably legacy?)", - ) - group.add( - '--attention_dropout', - '-attention_dropout', - type=float, - default=[0.1], - nargs='+', - help="Attention Dropout probability.", - ) group.add( '--dropout_steps', '-dropout_steps', type=int, nargs='+', default=[0], help="Steps at which dropout changes." ) diff --git a/mammoth/utils/parse.py b/mammoth/utils/parse.py index a864d02c..38f4690b 100644 --- a/mammoth/utils/parse.py +++ b/mammoth/utils/parse.py @@ -46,7 +46,7 @@ def _validate_tasks(cls, opts): default_transforms = opts.transforms if len(default_transforms) != 0: logger.info(f"Default transforms: {default_transforms}.") - corpora = yaml.safe_load(opts.tasks) + corpora = yaml_or_dict(opts.tasks, name='opts.tasks') logger.info("Parsing corpora") n_without_node_gpu = 0 for cname, corpus in corpora.items(): @@ -148,11 +148,11 @@ def _validate_tasks(cls, opts): logger.info(f"Parsed {len(corpora)} corpora from -data.") opts.tasks = corpora - src_vocab = yaml.safe_load(opts.src_vocab) + src_vocab = yaml_or_dict(opts.src_vocab, name="opts.src_vocab") logger.info(f"Parsed {len(src_vocab)} vocabs from -src_vocab.") opts.src_vocab = src_vocab - tgt_vocab = yaml.safe_load(opts.tgt_vocab) + tgt_vocab = yaml_or_dict(opts.tgt_vocab, name="opts.tgt_vocab") logger.info(f"Parsed {len(tgt_vocab)} vocabs from -tgt_vocab.") opts.tgt_vocab = tgt_vocab @@ -185,7 +185,7 @@ def _validate_fields_opts(cls, opts): if cname != CorpusName.VALID and corpus["src_feats"] is not None: assert opts.src_feats_vocab, "-src_feats_vocab is required if using source features." if isinstance(opts.src_feats_vocab, str): - opts.src_feats_vocab = yaml.safe_load(opts.src_feats_vocab) + opts.src_feats_vocab = yaml_or_dict(opts.src_feats_vocab, name="opts.src_feats_vocab") for feature in corpus["src_feats"].keys(): assert feature in opts.src_feats_vocab, f"No vocab file set for feature {feature}" @@ -267,7 +267,7 @@ def validate_x_transformers_opts(cls, opts): if not opts.x_transformers_opts: opts.x_transformers_opts = dict() return - opts_dict = yaml.safe_load(opts.x_transformers_opts) + opts_dict = yaml_or_dict(opts.x_transformers_opts, name="opts.x_transformers_opts") for overwritten_key in ( 'dim', 'depth',