Commit

reimplemented changes
Joseph Attieh authored and Joseph Attieh committed Oct 14, 2024
1 parent 97bc2d9 commit 1ed939f
Showing 4 changed files with 2,482 additions and 4 deletions.
20 changes: 18 additions & 2 deletions mammoth/model_builder.py
@@ -8,7 +8,7 @@
 from functools import partial
 from torch.nn.init import xavier_uniform_
 from typing import Optional, List, Dict, Tuple
-from x_transformers import TransformerWrapper
+from mammoth.modules.x_tf import TransformerWrapper
 from x_transformers.x_transformers import TokenEmbedding
 
 from mammoth.distributed.components import (
@@ -31,6 +31,20 @@
 from mammoth.utils.logging import logger
 from mammoth.utils.misc import use_gpu
 
+from torch.nn import Module
+# embedding
+import torch.nn.functional as F
+class ByteEmbedding(Module):
+    def __init__(self, dim, num_tokens, l2norm_embed = False):
+        super().__init__()
+        self.emb = nn.Embedding(num_tokens, dim)
+        one_hot_matrix = F.one_hot(torch.arange(num_tokens)).float()
+        one_hot_embed = torch.cat((one_hot_matrix, torch.zeros((num_tokens, dim - num_tokens))), dim=1)
+        self.emb.weight = torch.nn.parameter.Parameter(one_hot_embed, requires_grad=False)
+    def forward(self, x):
+        token_emb = self.emb(x.long())
+        return token_emb
+
 TRANSFORMER_WRAPPER_OPTS = {
     'post_emb_norm',
     'tie_embedding',
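
The ByteEmbedding added above swaps a learned lookup table for a frozen identity mapping: token id i becomes a one-hot vector over num_tokens positions, zero-padded out to the model dimension, and the weight matrix is created with requires_grad=False so it is never trained. A minimal sketch of how it behaves (illustrative only, not part of the commit; the dim and num_tokens values are arbitrary, the import path follows the changed file mammoth/model_builder.py, and dim must be at least num_tokens for the zero padding to be constructible):

    import torch
    from mammoth.model_builder import ByteEmbedding

    emb = ByteEmbedding(dim=512, num_tokens=256)        # e.g. a byte-level vocabulary
    ids = torch.tensor([[3, 7, 255]])                    # (batch, seq_len) of token ids
    out = emb(ids)                                       # (1, 3, 512)

    assert out.shape == (1, 3, 512)
    assert out[0, 0, 3].item() == 1.0                    # one-hot position for id 3
    assert out[0, 0, :256].sum().item() == 1.0           # exactly one non-zero entry
    assert out[0, 0, 256:].abs().sum().item() == 0.0     # padding beyond num_tokens stays zero
    assert not emb.emb.weight.requires_grad              # frozen, excluded from training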
@@ -210,7 +224,8 @@ def build_xcoder(
     for lang in all_langs:
         if lang not in token_embs:
             vocab = vocabs_dict[(side_alt_str, lang)]
-            token_embs[lang] = TokenEmbedding(
+            Embedding = ByteEmbedding if model_opts.use_embeddingless else TokenEmbedding
+            token_embs[lang] = Embedding(
                 dim=model_opts.model_dim,
                 num_tokens=len(vocab),
                 l2norm_embed=l2norm_embed
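
With model_opts.use_embeddingless set, every language's token embedding is built as a frozen ByteEmbedding instead of the learned TokenEmbedding from x_transformers. Because ByteEmbedding pads each one-hot row out to the model dimension, this implicitly assumes len(vocab) <= model_opts.model_dim; otherwise torch.zeros((num_tokens, dim - num_tokens)) above receives a negative size and fails. A guard along these lines (purely illustrative, not part of the commit) would make that failure mode explicit:

    if model_opts.use_embeddingless and len(vocab) > model_opts.model_dim:
        raise ValueError('use_embeddingless requires model_dim >= vocab size for one-hot embeddings')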
@@ -245,6 +260,7 @@
             attn_layers=adapted_attention_layers_stack,
             emb_dim=model_opts.model_dim,
             token_emb=token_embs[lang],
+            initialize_embeddings=not (model_opts.use_embeddingless),
             **transformer_wrapper_kwargs,
         )
         transformer_wrappers[task.corpus_id] = transformer_wrapper
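
Two related changes come together here: TransformerWrapper is now imported from the local copy in mammoth.modules.x_tf (presumably the bulk of the 2,482 added lines in this commit) rather than from the upstream x_transformers package, and the wrapper is given a new initialize_embeddings keyword that the upstream class does not expose. Passing not use_embeddingless suggests the local copy skips its token-embedding weight initialization in the embeddingless case, so the frozen one-hot table built by ByteEmbedding is not overwritten. A hypothetical sketch of that gate (an assumption about mammoth/modules/x_tf.py, not code from this commit; it mirrors the upstream TransformerWrapper.init_() behaviour):

    # hypothetical excerpt from the local wrapper's init logic
    def init_(self):
        if not self.initialize_embeddings:
            return  # keep ByteEmbedding's frozen one-hot weights untouched
        if self.l2norm_embed:
            nn.init.normal_(self.token_emb.emb.weight, std=1e-5)
        else:
            nn.init.kaiming_normal_(self.token_emb.emb.weight)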
