Commit 56ac097: Code review fixes

Waino committed Dec 9, 2024 · 1 parent a98e9f8

Showing 3 changed files with 48 additions and 37 deletions.
mammoth/model_builder.py: 42 changes (27 additions & 15 deletions)
@@ -101,25 +101,31 @@ def build_adapters(
     task_queue_manager,
     single_task: Optional[str] = None,
 ) -> Optional[Dict[str, Adapter]]:
-    # Create AdapterLayer objects and Adapter objects
+    """
+    Create AdapterLayer objects and Adapter objects
+    """
+    adapters_by_name: Optional[Dict[str, Adapter]]
     if side == Side.encoder:
         side_str = 'encoder'
     else:
         side_str = 'decoder'
     my_components: List[DistributedComponent] = task_queue_manager.get_my_distributed_components()
-    my_components = [
+    my_side_specific_components = [
         component for component in my_components
         if hasattr(component, 'side') and component.side == side
     ]
 
     if single_task:
-        my_components = [
-            component for component in my_components
+        components_to_create = [
+            component for component in my_side_specific_components
             if single_task in component.task_ids
         ]
+    else:
+        components_to_create = my_side_specific_components
+
     if uses_adapters(model_opts):
         adapter_components = [
-            component for component in my_components
+            component for component in components_to_create
             if isinstance(component, DistributedAdapter) and component.side == side
         ]
         adapters_by_name = dict()
@@ -188,26 +194,32 @@ def build_xcoder(
         token_embs: to tie encoder and decoder embeddings, pass existing embeddings here.
     """
     my_components: List[DistributedComponent] = task_queue_manager.get_my_distributed_components()
-    my_components = [
+    my_side_specific_components = [
         component for component in my_components
         if hasattr(component, 'side') and component.side == side
     ]
-    distributed_xcoder_class: type
-    if side == Side.encoder:
-        distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
-    else:
-        distributed_xcoder_class = DistributedDecoderAttentionLayersBlock
 
     if single_task:
-        my_components = [
-            component for component in my_components
+        components_to_create = [
+            component for component in my_side_specific_components
             if single_task in component.task_ids
         ]
+    else:
+        components_to_create = my_side_specific_components
+
+    # Create AdaptedAttentionLayers objects (an extension of an x_transformers.AttentionLayers block)
+    distributed_xcoder_class: type
+    if side == Side.encoder:
+        distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
+    elif side == Side.decoder:
+        distributed_xcoder_class = DistributedDecoderAttentionLayersBlock
+    else:
+        raise TypeError(type(side))
     attention_layers_components = [
-        component for component in my_components
+        component for component in components_to_create
         if isinstance(component, distributed_xcoder_class)
     ]
 
     attention_layer_blocks: Dict[int, Dict[str, AdaptedAttentionLayers]] = defaultdict(dict)
     for component in attention_layers_components:
         layer_stack_index = component.layer_stack_index
@@ -385,7 +397,7 @@ def build_model(
     for component in task_queue_manager.get_my_distributed_components():
         logger.info(component)
     for name, p in model.named_parameters():
-        print(f'{p.requires_grad} {name}')
+        logger.info(f'{p.requires_grad} {name}')
     logger.info('Building model - done!')
     return model
 
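Both model_builder.py hunks apply the same refactor: instead of repeatedly rebinding `my_components` after each filtering step, every stage now gets its own name (`my_side_specific_components`, `components_to_create`), and the side dispatch in `build_xcoder` now fails loudly on an unexpected value instead of silently treating anything non-encoder as a decoder. A minimal sketch of the resulting selection pipeline, using stand-in types rather than mammoth's real `DistributedComponent` hierarchy:

    # Illustrative stand-ins only; mammoth's DistributedComponent classes are richer.
    from dataclasses import dataclass
    from enum import Enum
    from typing import List, Optional

    class Side(Enum):
        encoder = 'encoder'
        decoder = 'decoder'

    @dataclass
    class Component:
        side: Side
        task_ids: List[str]

    def select_components(
        my_components: List[Component],
        side: Side,
        single_task: Optional[str] = None,
    ) -> List[Component]:
        # Stage 1: keep only components on the requested side.
        my_side_specific_components = [
            component for component in my_components
            if getattr(component, 'side', None) == side
        ]
        # Stage 2: optionally narrow down to the components of a single task.
        if single_task:
            return [
                component for component in my_side_specific_components
                if single_task in component.task_ids
            ]
        return my_side_specific_components

Giving each stage a distinct name keeps the full, side-filtered, and task-filtered lists all available, so later code cannot accidentally read a narrower list than it intended.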
mammoth/opts.py: 33 changes (16 additions & 17 deletions)
@@ -333,6 +333,22 @@ def model_opts(parser):
         help='Number of heads for transformer self-attention. '
              ' Semi-obsolete: not used for x-transformers, only used for some attention bridge configuations.'
     )
+    group.add(
+        '--dropout',
+        '-dropout',
+        type=float,
+        default=[0.3],
+        nargs='+',
+        help="Dropout probability; Legacy: applied in the attention bridge",
+    )
+    group.add(
+        '--attention_dropout',
+        '-attention_dropout',
+        type=float,
+        default=[0.1],
+        nargs='+',
+        help="Attention Dropout probability; Legacy: applied in the attention bridge",
+    )
 
     # adapter options are in a dict "adapters", and in the corpus options
     group = parser.add_argument_group("Adapters")
@@ -547,23 +563,6 @@ def _add_train_general_opts(parser):
         default=0.0,
         help="L2 penalty (weight decay) regularizer",
     )
-    # FIXME, mentions LSTM
-    group.add(
-        '--dropout',
-        '-dropout',
-        type=float,
-        default=[0.3],
-        nargs='+',
-        help="Dropout probability; applied in LSTM stacks. (Probably legacy?)",
-    )
-    group.add(
-        '--attention_dropout',
-        '-attention_dropout',
-        type=float,
-        default=[0.1],
-        nargs='+',
-        help="Attention Dropout probability.",
-    )
     group.add(
         '--dropout_steps', '-dropout_steps', type=int, nargs='+', default=[0], help="Steps at which dropout changes."
     )
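The net effect in opts.py is that `--dropout` and `--attention_dropout` move from the general training options into the model options group, with the stale LSTM-era help text replaced by a note that they are legacy settings applied in the attention bridge. A self-contained sketch of how such list-valued options parse (plain argparse rather than mammoth's actual parser; the pairing with `--dropout_steps` is inferred from the help texts above, not confirmed by this diff):

    # Plain-argparse sketch of the relocated options; the group name is illustrative.
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("model")
    group.add_argument('--dropout', '-dropout', type=float, default=[0.3], nargs='+',
                       help="Dropout probability; Legacy: applied in the attention bridge")
    group.add_argument('--attention_dropout', '-attention_dropout', type=float,
                       default=[0.1], nargs='+',
                       help="Attention Dropout probability; Legacy: applied in the attention bridge")
    group.add_argument('--dropout_steps', '-dropout_steps', type=int, nargs='+',
                       default=[0], help="Steps at which dropout changes.")

    opts = parser.parse_args('--dropout 0.3 0.1 --dropout_steps 0 10000'.split())
    print(opts.dropout)        # [0.3, 0.1]
    print(opts.dropout_steps)  # [0, 10000]: presumably 0.3 until step 10000, then 0.1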
mammoth/utils/parse.py: 10 changes (5 additions & 5 deletions)
@@ -46,7 +46,7 @@ def _validate_tasks(cls, opts):
         default_transforms = opts.transforms
         if len(default_transforms) != 0:
             logger.info(f"Default transforms: {default_transforms}.")
-        corpora = yaml.safe_load(opts.tasks)
+        corpora = yaml_or_dict(opts.tasks, name='opts.tasks')
         logger.info("Parsing corpora")
         n_without_node_gpu = 0
         for cname, corpus in corpora.items():
@@ -148,11 +148,11 @@ def _validate_tasks(cls, opts):
         logger.info(f"Parsed {len(corpora)} corpora from -data.")
         opts.tasks = corpora
 
-        src_vocab = yaml.safe_load(opts.src_vocab)
+        src_vocab = yaml_or_dict(opts.src_vocab, name="opts.src_vocab")
         logger.info(f"Parsed {len(src_vocab)} vocabs from -src_vocab.")
         opts.src_vocab = src_vocab
 
-        tgt_vocab = yaml.safe_load(opts.tgt_vocab)
+        tgt_vocab = yaml_or_dict(opts.tgt_vocab, name="opts.tgt_vocab")
         logger.info(f"Parsed {len(tgt_vocab)} vocabs from -tgt_vocab.")
         opts.tgt_vocab = tgt_vocab
 
@@ -185,7 +185,7 @@ def _validate_fields_opts(cls, opts):
             if cname != CorpusName.VALID and corpus["src_feats"] is not None:
                 assert opts.src_feats_vocab, "-src_feats_vocab is required if using source features."
                 if isinstance(opts.src_feats_vocab, str):
-                    opts.src_feats_vocab = yaml.safe_load(opts.src_feats_vocab)
+                    opts.src_feats_vocab = yaml_or_dict(opts.src_feats_vocab, name="opts.src_feats_vocab")
 
                 for feature in corpus["src_feats"].keys():
                     assert feature in opts.src_feats_vocab, f"No vocab file set for feature {feature}"
@@ -267,7 +267,7 @@ def validate_x_transformers_opts(cls, opts):
         if not opts.x_transformers_opts:
             opts.x_transformers_opts = dict()
             return
-        opts_dict = yaml.safe_load(opts.x_transformers_opts)
+        opts_dict = yaml_or_dict(opts.x_transformers_opts, name="opts.x_transformers_opts")
         for overwritten_key in (
             'dim',
             'depth',
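All five parse.py changes replace bare `yaml.safe_load` calls with a `yaml_or_dict` helper that also receives the option name. The helper itself is not part of this diff; a plausible reconstruction, assuming it passes already-parsed dicts through unchanged and uses `name` for clearer error messages:

    # Hypothetical reconstruction of yaml_or_dict; the real helper in mammoth may differ.
    from typing import Any, Dict

    import yaml

    def yaml_or_dict(value: Any, name: str) -> Dict:
        """Return `value` as a dict, parsing it as YAML if it is a string."""
        if isinstance(value, dict):
            # Already parsed (e.g. when options are validated a second time): pass through.
            return value
        parsed = yaml.safe_load(value)
        if not isinstance(parsed, dict):
            raise ValueError(f"{name} must be a dict or a YAML mapping, got {type(parsed).__name__}")
        return parsed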
