Keep a separate dict of attention_layer_blocks
This structure allows retrieval of AdaptedAttentionLayers using the
layer_stack_index and the xcoder_id.
The existing structure requires knowing the task_id of a task in which the
component is used. That task_id can be difficult to acquire in some contexts.
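
A minimal, standalone sketch (using a toy stand-in class and hypothetical xcoder ids, not the real mammoth objects) of the lookup that the separate dict enables:

from typing import Dict

class AdaptedAttentionLayers:  # toy stand-in for the real mammoth class
    def __init__(self, xcoder_id: str):
        self.xcoder_id = xcoder_id

# Keyed first by layer_stack_index, then by xcoder_id.
attention_layer_blocks: Dict[int, Dict[str, AdaptedAttentionLayers]] = {
    0: {'xcoder_a': AdaptedAttentionLayers('xcoder_a'),
        'xcoder_b': AdaptedAttentionLayers('xcoder_b')},
    1: {'xcoder_c': AdaptedAttentionLayers('xcoder_c')},
}

# Retrieval needs only (layer_stack_index, xcoder_id); no task_id is involved.
aal = attention_layer_blocks[0]['xcoder_b']
assert aal.xcoder_id == 'xcoder_b'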
Waino committed Sep 9, 2024
1 parent 1a96157 commit 803fb58
Showing 3 changed files with 24 additions and 16 deletions.
10 changes: 2 additions & 8 deletions mammoth/distributed/components.py
@@ -131,10 +131,7 @@ def encoder_id(self) -> str:
         return self.xcoder_id
 
     def get_module(self, model: NMTModel) -> nn.Module:
-        a_task_id = sorted(self.task_ids)[0]
-        aal = model.encoder.get_attention_layers(a_task_id, self.layer_stack_index)
-        assert aal.xcoder_id == self.xcoder_id, \
-            f'{self.get_name()} {self.layer_stack_index}: expected {self.xcoder_id} found {aal.xcoder_id}'
+        aal = model.encoder.get_attention_layers_by_xcoder_id(self.layer_stack_index, self.xcoder_id)
         return aal


@@ -149,10 +146,7 @@ def decoder_id(self) -> str:
         return self.xcoder_id
 
     def get_module(self, model: NMTModel) -> nn.Module:
-        a_task_id = sorted(self.task_ids)[0]
-        aal = model.decoder.get_attention_layers(a_task_id, self.layer_stack_index)
-        assert aal.xcoder_id == self.xcoder_id, \
-            f'{self.get_name()} {self.layer_stack_index}: expected {self.xcoder_id} found {aal.xcoder_id}'
+        aal = model.decoder.get_attention_layers_by_xcoder_id(self.layer_stack_index, self.xcoder_id)
         return aal


6 changes: 5 additions & 1 deletion mammoth/model_builder.py
@@ -327,7 +327,11 @@ def build_xcoder(
         transformer_wrappers[task.corpus_id] = transformer_wrapper
 
     # Create a StackXcoder
-    stack_xcoder = StackXcoder(transformer_wrappers, token_embs=token_embs)
+    stack_xcoder = StackXcoder(
+        transformer_wrappers=transformer_wrappers,
+        attention_layer_blocks=attention_layer_blocks,
+        token_embs=token_embs,
+    )
     return stack_xcoder


24 changes: 17 additions & 7 deletions mammoth/modules/layer_stack.py
@@ -1,6 +1,7 @@
 from torch import nn
-from typing import List, Sequence, Optional, Tuple
-from x_transformers.x_transformers import LayerIntermediates
+from typing import List, Sequence, Optional, Tuple, Dict
+from x_transformers import TransformerWrapper
+from x_transformers.x_transformers import LayerIntermediates, TokenEmbedding
 
 from mammoth.modules.adapters import AdaptedAttentionLayers

@@ -59,10 +60,16 @@ class StackXcoder(nn.ModuleDict):
     """
     Switches between different AdaptedAttentionLayersStacks depending on the task.
     """
-    def __init__(self, *args, token_embs=None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.active_task = None
-        self.token_embs = token_embs
+    def __init__(
+        self,
+        transformer_wrappers: Dict[str, TransformerWrapper],
+        attention_layer_blocks: Dict[int, Dict[str, AdaptedAttentionLayers]],
+        token_embs: Dict[str, TokenEmbedding],
+    ):
+        super().__init__(transformer_wrappers)
+        self.attention_layers_by_xcoder_id: Dict[int, Dict[str, AdaptedAttentionLayers]] = attention_layer_blocks
+        self.token_embs: Dict[str, TokenEmbedding] = token_embs
+        self.active_task: Optional[str] = None
 
     # TransformerWrapper wraps an AttentionLayers in embeddings and some other functionality.
     # We use one TransformerWrapper per task.
@@ -76,9 +83,12 @@ def activate(self, task_id: str, adapter_ids: Optional[List[Tuple[int, str, str]
             attention_layers_stack.activate_adapter(layer_stack_index, adapter_group, sub_id)
         return transformer_wrapper
 
-    def get_attention_layers(self, task_id: str, layer_stack_index: int) -> AdaptedAttentionLayers:
+    def get_attention_layers_by_task_id(self, task_id: str, layer_stack_index: int) -> AdaptedAttentionLayers:
         return self[task_id].attn_layers.attention_layers_stack[layer_stack_index]
 
+    def get_attention_layers_by_xcoder_id(self, layer_stack_index: int, xcoder_id: str) -> AdaptedAttentionLayers:
+        return self.attention_layers_by_xcoder_id[layer_stack_index][xcoder_id]
+
     def get_embedding_by_task_id(self, task_id):
         transformer_wrapper = self[task_id]
         return transformer_wrapper.token_emb
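
For orientation, a self-contained toy sketch (simplified classes and hypothetical ids, not the real mammoth or x_transformers objects) contrasting what the caller must know for each accessor:

from typing import Dict, List

class ToyAttentionLayers:
    def __init__(self, xcoder_id: str):
        self.xcoder_id = xcoder_id

class ToyStackXcoder(dict):
    """Mimics only the two lookup paths of StackXcoder."""
    def __init__(
        self,
        stacks_by_task: Dict[str, List[ToyAttentionLayers]],
        attention_layer_blocks: Dict[int, Dict[str, ToyAttentionLayers]],
    ):
        super().__init__(stacks_by_task)
        self.attention_layers_by_xcoder_id = attention_layer_blocks

    def get_attention_layers_by_task_id(self, task_id: str, layer_stack_index: int) -> ToyAttentionLayers:
        # Requires knowing a task in which the component is used.
        return self[task_id][layer_stack_index]

    def get_attention_layers_by_xcoder_id(self, layer_stack_index: int, xcoder_id: str) -> ToyAttentionLayers:
        # Requires only the layer stack index and the xcoder id.
        return self.attention_layers_by_xcoder_id[layer_stack_index][xcoder_id]

shared = ToyAttentionLayers('xcoder_a')
stack = ToyStackXcoder({'task_en_de': [shared]}, {0: {'xcoder_a': shared}})
assert stack.get_attention_layers_by_task_id('task_en_de', 0) is stack.get_attention_layers_by_xcoder_id(0, 'xcoder_a')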
