WIP: Replace the custom Transformer implementation with x-transformers
External dependencies for layer architectures #56
Waino committed Jul 29, 2024
1 parent 4a32776 commit e09763f
Showing 13 changed files with 679 additions and 1,261 deletions.
12 changes: 6 additions & 6 deletions config/example.hamburger.yaml
@@ -20,43 +20,43 @@ adapters:
     enc_group:
       layer_stack_index: 0
       layers: [0, 1]
-      hidden_size: 8 # 512 (rnn_size) / 64 (reduction factor)
+      hidden_dim: 8 # 512 (rnn_size) / 64 (reduction factor)
       ids:
         - foo
         - bar
     enc_highresource:
       layer_stack_index: 0
       layers: [0, 1]
-      hidden_size: 8
+      hidden_dim: 8
       ids:
         - en
         - de
     enc_lowresource:
       layer_stack_index: 0
       layers: [0]
-      hidden_size: 8
+      hidden_dim: 8
       ids:
         - uu
   decoder:
     dec_group:
       layer_stack_index: 0
       layers: [0]
-      hidden_size: 8
+      hidden_dim: 8
       ids:
         - foo
         - bar
     dec_highresource:
       layer_stack_index: 1
       layers: [0, 1]
-      hidden_size: 16
+      hidden_dim: 16
       ids:
         - en
         - de
         - fr
     dec_lowresource:
       layer_stack_index: 1
       layers: [0]
-      hidden_size: 8
+      hidden_dim: 8
       ids:
         - vv

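The only change in this file is the rename of the adapter bottleneck key from hidden_size to hidden_dim, aligning the config schema with the dim-based naming that x-transformers uses. As the inline comment notes, the bottleneck width is the model width divided by a reduction factor; a minimal sketch of that arithmetic (variable names here are illustrative, not MAMMOTH API):

# Hedged sketch of the bottleneck arithmetic from the inline comment above;
# only the resulting value appears in the config.
rnn_size = 512          # model width
reduction_factor = 64   # adapter bottleneck reduction
hidden_dim = rnn_size // reduction_factor
assert hidden_dim == 8  # matches hidden_dim: 8 in the YAML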
12 changes: 6 additions & 6 deletions examples/config_config.yaml
@@ -70,28 +70,28 @@ adapters:
     enc_lang_bottom:
       layer_stack_index: 0
       layers: [0, 1, 2]
-      hidden_size: 8
+      hidden_dim: 8
       ids: LANGUAGE
     enc_lang_top:
       layer_stack_index: 1
       layers: [0, 1, 2]
-      hidden_size: 8
+      hidden_dim: 8
       ids: LANGUAGE
   decoder:
     dec_lang_bottom:
       layer_stack_index: 0
       layers: [0, 1]
-      hidden_size: 16
+      hidden_dim: 16
       ids: LANGUAGE
     dec_lang_mid:
       layer_stack_index: 1
       layers: [0, 1, 2]
-      hidden_size: 16
+      hidden_dim: 16
       ids: LANGUAGE
     dec_lang_top:
       layer_stack_index: 2
       layers: [0]
-      hidden_size: 16
+      hidden_dim: 16
       ids: LANGUAGE

 save_model: models/opus.spm32k.adafactor.hamburger.l2.dsae/opus.spm32k.adafactor.hamburger.l2.dsae
@@ -107,7 +107,7 @@ encoder_type: transformer
 decoder_type: transformer
 rnn_size: 512
 word_vec_size: 512
-transformer_ff: 2048
+ff_mult: 4
 heads: 8
 enc_layers: [3, 3]
 dec_layers: [2, 3, 1]
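Beyond the same hidden_size to hidden_dim rename, note the hyperparameter change in the second hunk: x-transformers sizes the feed-forward block as a multiple of the model dimension rather than as an absolute width, so transformer_ff: 2048 becomes ff_mult: 4 (4 × 512 = 2048 with rnn_size: 512). A minimal sketch of how the library consumes these values, using a bare x-transformers Encoder rather than MAMMOTH's own wrappers:

# Hedged sketch: plain x-transformers usage, not MAMMOTH code.
from x_transformers import Encoder

encoder = Encoder(
    dim=512,    # model width (rnn_size / word_vec_size in the config)
    depth=3,    # one layer stack; enc_layers: [3, 3] builds two such stacks
    heads=8,    # attention heads
    ff_mult=4,  # feed-forward inner dim = 4 * 512 = 2048
)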
1 change: 1 addition & 0 deletions mammoth/distributed/components.py
@@ -77,6 +77,7 @@ def needs_communication(self) -> bool:
         return self.group is not None


+# TODO: This is a misnomer: Not an entire XCoder, but just one AttentionLayers block
 @dataclass
 class DistributedXCoder(DistributedComponent, ABC):
     layer_stack_index: int
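The added TODO flags a naming mismatch that this migration introduces: in x-transformers, AttentionLayers is the reusable stack of attention and feed-forward layers (Encoder and Decoder are thin subclasses of it), and a DistributedXCoder now corresponds to one such block selected by layer_stack_index, not to a whole encoder or decoder. A hedged sketch of the distinction, again using the library directly rather than MAMMOTH's wrappers:

# Hedged sketch: one AttentionLayers block per layer stack, not a full XCoder.
from x_transformers.x_transformers import AttentionLayers

# A single non-causal block, e.g. the first encoder layer stack of depth 2.
block = AttentionLayers(dim=512, depth=2, heads=8, causal=False)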
