Skip to content

Commit

Permalink
Fix bugs in adapter implementation
Browse files Browse the repository at this point in the history
- The top-level adapter_type option was removed and is now specified inside
  each adapter definition. This allows different types of adapter to coexist.
- Remove a level of looping to avoid redundant adapters.
- Accept an empty list to specify zero adapters for enc/dec (in addition
  to the empty dict which is more correct)
- Store adapters by name in StackXcoder
  • Loading branch information
Waino committed Nov 18, 2024
1 parent dd4e1ff commit 1cd30ef
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 48 deletions.
6 changes: 3 additions & 3 deletions mammoth/distributed/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, A
module = self.get_module(model)
destination: Dict[str, Any] = OrderedDict()
for name, sub_module in module._modules.items():
if name.endswith('attn_layers'):
if name in {'attn_layers', 'token_emb'}:
# stored separately
continue
sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
Expand Down Expand Up @@ -240,9 +240,9 @@ def get_name(self) -> str:

def get_module(self, model: NMTModel) -> nn.Module:
    """Return the Adapter module this distributed component refers to.

    Adapters are stored by name in the encoder/decoder stack; the name is
    derived from this component's adapter_group and sub_id.
    """
    # Build the lookup key once; both branches use the same naming scheme.
    adapter_name = f'adapter_{self.adapter_group}_{self.sub_id}'
    if self.side == Side.encoder:
        return model.encoder.get_adapter(adapter_name)
    else:
        return model.decoder.get_adapter(adapter_name)


@dataclass
Expand Down
39 changes: 25 additions & 14 deletions mammoth/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,13 @@ def build_xcoder(
)

# Create AdapterLayer objects and Adapter objects
adapters_by_name: Optional[Dict[str, Adapter]]
if uses_adapters(model_opts):
adapter_components = [
component for component in my_components
if isinstance(component, DistributedAdapter) and component.side == side
]
adapters_by_name = dict()
adapter_params_by_group = dict()
for adapter_group, adapter_opts in model_opts.adapters[side_str].items():
adapter_params_by_group[adapter_group] = {
Expand All @@ -169,13 +171,13 @@ def build_xcoder(
}
for component in adapter_components:
adapter_params = adapter_params_by_group[component.adapter_group]
if model_opts.adapter_type.lower() == 'lora':
if adapter_opts['adapter_type'].lower() == 'lora':
adapter_layer_func = partial(
LoraAdapterLayer,
dim=model_opts.model_dim,
r=adapter_params['hidden_dim'],
)
elif model_opts.adapter_type.lower() == 'ff':
elif adapter_opts['adapter_type'].lower() == 'ff':
mult = adapter_params['hidden_dim'] / model_opts.model_dim
# TODO: make norm locations and glu configurable
adapter_layer_func = partial(
Expand All @@ -187,18 +189,26 @@ def build_xcoder(
glu=True,
)
else:
raise ValueError(f'Unrecognized adapter_type {model_opts.adapter_type}')
for sub_id in adapter_params['sub_ids']:
for layer_idx in adapter_params['layers']:
adapter_layer = adapter_layer_func()
adapter = Adapter(
adapter_group=component.adapter_group,
sub_id=sub_id,
)
adapter.add_layer(layer_idx, adapter_layer)
layer_stack_index = adapter_params['layer_stack_index']
for attention_layers in attention_layer_blocks[layer_stack_index]:
attention_layers.add_adapter(adapter)
raise ValueError(f'Unrecognized adapter_type {adapter_opts["adapter_type"]}')
adapter = Adapter(
adapter_group=component.adapter_group,
sub_id=component.sub_id,
)
adapters_by_name[adapter.name] = adapter
for layer_idx in adapter_params['layers']:
adapter_layer = adapter_layer_func()
adapter.add_layer(layer_idx, adapter_layer)
layer_stack_index = adapter_params['layer_stack_index']
for xcoder_id, attention_layers in attention_layer_blocks[layer_stack_index].items():
# TODO: allow limiting which xcoder_ids get the adapter?
logger.info(f'adding {adapter.name} to {layer_stack_index}:{xcoder_id}:{component.sub_id}')
try:
attention_layers.add_adapter(adapter)
except Exception as e:
logger.error(repr(attention_layers))
raise e
else:
adapters_by_name = None

# Create TokenEmbedding objects
l2norm_embed = False
Expand Down Expand Up @@ -256,6 +266,7 @@ def build_xcoder(
transformer_wrappers=transformer_wrappers,
attention_layer_blocks=attention_layer_blocks,
token_embs=token_embs,
adapters=adapters_by_name,
)
return stack_xcoder

Expand Down
7 changes: 6 additions & 1 deletion mammoth/modules/layer_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from x_transformers import TransformerWrapper
from x_transformers.x_transformers import LayerIntermediates, TokenEmbedding

from mammoth.modules.adapters import AdaptedAttentionLayers
from mammoth.modules.adapters import AdaptedAttentionLayers, Adapter


class AdaptedAttentionLayersStack(nn.Module):
Expand Down Expand Up @@ -65,11 +65,13 @@ def __init__(
transformer_wrappers: Dict[str, TransformerWrapper],
attention_layer_blocks: Dict[int, Dict[str, AdaptedAttentionLayers]],
token_embs: Dict[str, TokenEmbedding],
adapters: Optional[Dict[str, Adapter]],
):
super().__init__(transformer_wrappers)
self.attention_layers_by_xcoder_id: Dict[int, Dict[str, AdaptedAttentionLayers]] = attention_layer_blocks
self.token_embs: Dict[str, TokenEmbedding] = token_embs
self.active_task: Optional[str] = None
self.adapters = adapters

# TransformerWrapper wraps an AttentionLayers in embeddings and some other functionality.
# We use one TransformerWrapper per task.
Expand All @@ -96,4 +98,7 @@ def get_embedding_by_task_id(self, task_id):
def get_embedding_by_lang(self, lang):
return self.token_embs[lang]

def get_adapter(self, adapter_name):
    """Look up a registered Adapter by its name.

    Raises KeyError if no adapter with that name was registered.
    """
    registry = self.adapters
    return registry[adapter_name]

# Lack of forward is intentional: call forward on the return value of activate
62 changes: 32 additions & 30 deletions tools/config_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,36 +618,38 @@ def adapter_config(opts):
if 'adapters' not in task_config:
task_config['adapters'] = {'encoder': [], 'decoder': []}
# TODO: refactor and add support for {SRC|TGT}_{LANGUAGE|GROUP} also to adapters
for adapter_name, adapter_config in sorted(encoder_adapters.items()):
if adapter_config['ids'] == 'LANGUAGE':
adapter_config['ids'] = list(src_langs)
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['encoder'].append([adapter_name, task_src])
elif adapter_config['ids'] == 'GROUP':
adapter_config['ids'] = list(src_groups)
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['encoder'].append([adapter_name, cc_opts['groups'][task_src]])
elif adapter_config['ids'] == 'FULL':
adapter_config['ids'] = ['full']
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_config['adapters']['encoder'].append([adapter_name, 'full'])
for adapter_name, adapter_config in sorted(decoder_adapters.items()):
if adapter_config['ids'] == 'LANGUAGE':
adapter_config['ids'] = list(tgt_langs)
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['decoder'].append([adapter_name, task_tgt])
elif adapter_config['ids'] == 'GROUP':
adapter_config['ids'] = list(tgt_groups)
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['decoder'].append([adapter_name, cc_opts['groups'][task_tgt]])
elif adapter_config['ids'] == 'FULL':
adapter_config['ids'] = ['full']
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_config['adapters']['decoder'].append([adapter_name, 'full'])
if len(encoder_adapters) > 0:
for adapter_name, adapter_config in sorted(encoder_adapters.items()):
if adapter_config['ids'] == 'LANGUAGE':
adapter_config['ids'] = list(src_langs)
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['encoder'].append([adapter_name, task_src])
elif adapter_config['ids'] == 'GROUP':
adapter_config['ids'] = list(src_groups)
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['encoder'].append([adapter_name, cc_opts['groups'][task_src]])
elif adapter_config['ids'] == 'FULL':
adapter_config['ids'] = ['full']
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_config['adapters']['encoder'].append([adapter_name, 'full'])
if len(decoder_adapters) > 0:
for adapter_name, adapter_config in sorted(decoder_adapters.items()):
if adapter_config['ids'] == 'LANGUAGE':
adapter_config['ids'] = list(tgt_langs)
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['decoder'].append([adapter_name, task_tgt])
elif adapter_config['ids'] == 'GROUP':
adapter_config['ids'] = list(tgt_groups)
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['decoder'].append([adapter_name, cc_opts['groups'][task_tgt]])
elif adapter_config['ids'] == 'FULL':
adapter_config['ids'] = ['full']
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_config['adapters']['decoder'].append([adapter_name, 'full'])
opts.in_config[0]['adapters']['encoder'] = encoder_adapters
opts.in_config[0]['adapters']['decoder'] = decoder_adapters

Expand Down

0 comments on commit 1cd30ef

Please sign in to comment.