
Commit

🐛 Llama bias bug fix
joey00072 committed Jul 18, 2024
1 parent 9f00955 commit a989359
Showing 2 changed files with 33 additions and 37 deletions.
68 changes: 32 additions & 36 deletions ohara/models/llama.py
@@ -27,24 +27,22 @@ class Config:
 
 
 class Attention(nn.Module):
-    def __init__(self, model_args: Config):
+    def __init__(self, cfg: Config):
         super().__init__()
-        d_model = model_args.d_model
-        self.num_heads = model_args.num_heads
-        self.head_dim = model_args.d_model // model_args.num_heads
-        self.num_kv_heads = (
-            model_args.num_heads if model_args.num_kv_heads == 0 else model_args.num_kv_heads
-        )
+        d_model = cfg.d_model
+        self.num_heads = cfg.num_heads
+        self.head_dim = cfg.d_model // cfg.num_heads
+        self.num_kv_heads = cfg.num_heads if cfg.num_kv_heads == 0 else cfg.num_kv_heads
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        self.key = nn.Linear(d_model, self.head_dim * self.num_heads)
-        self.query = nn.Linear(d_model, self.head_dim * self.num_kv_heads)
-        self.value = nn.Linear(d_model, self.head_dim * self.num_kv_heads)
-        self.proj = nn.Linear(d_model, d_model, model_args.bias)
-        self.attn_dropout = nn.Dropout(model_args.dropout)
-        self.res_dropout = nn.Dropout(model_args.dropout)
+        self.key = nn.Linear(d_model, self.head_dim * self.num_heads, cfg.bias)
+        self.query = nn.Linear(d_model, self.head_dim * self.num_kv_heads, cfg.bias)
+        self.value = nn.Linear(d_model, self.head_dim * self.num_kv_heads, cfg.bias)
+        self.proj = nn.Linear(d_model, d_model, cfg.bias)
+
+        self.attn_dropout = nn.Dropout(cfg.dropout)
+        self.res_dropout = nn.Dropout(cfg.dropout)
 
         self.flash_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention")
 
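This is the bug in the title: the third positional argument of nn.Linear is `bias` and it defaults to True, so before this commit the key, query, and value projections silently kept bias parameters even when the config asked for a bias-free model; only `proj` honored the flag. A minimal sketch of the difference, assuming cfg.bias = False (the width of 512 is made up):

    import torch.nn as nn

    old_key = nn.Linear(512, 512)         # bias flag omitted, so bias=True: the bug
    new_key = nn.Linear(512, 512, False)  # cfg.bias passed through: the fix
    print(old_key.bias is not None)       # True
    print(new_key.bias is not None)       # False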
@@ -103,19 +101,19 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor, freqs_cis) -> torch.Tensor:
 
 
 class Block(nn.Module):
-    def __init__(self, model_args: Config):
+    def __init__(self, cfg: Config):
         super().__init__()
 
-        self.attn = Attention(model_args)
+        self.attn = Attention(cfg)
         self.ff = SwiGLU(
-            dim=model_args.d_model,
-            hidden_dim=model_args.hidden_dim,
-            dropout=model_args.dropout,
-            bias=model_args.bias,
+            dim=cfg.d_model,
+            hidden_dim=cfg.hidden_dim,
+            dropout=cfg.dropout,
+            bias=cfg.bias,
         )
 
-        self.norm1 = RMSNorm(model_args.d_model)
-        self.norm2 = RMSNorm(model_args.d_model)
+        self.norm1 = RMSNorm(cfg.d_model)
+        self.norm2 = RMSNorm(cfg.d_model)
 
     def forward(self, x, mask, freqs_cis):
         x = x + self.attn(self.norm1(x), mask, freqs_cis)
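For orientation, Block is a standard pre-norm transformer layer: normalize the input, run the sub-module, add the result back as a residual. A self-contained stand-in with made-up sizes, using stock PyTorch modules in place of this repo's RMSNorm, SwiGLU, and rotary attention:

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        def __init__(self, d_model: int):
            super().__init__()
            self.attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
            self.ff = nn.Sequential(
                nn.Linear(d_model, 4 * d_model), nn.SiLU(), nn.Linear(4 * d_model, d_model)
            )
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)

        def forward(self, x):
            h = self.norm1(x)               # normalize before the sub-block (pre-norm)
            x = x + self.attn(h, h, h)[0]   # residual around attention
            x = x + self.ff(self.norm2(x))  # residual around the feed-forward
            return x

    print(TinyBlock(64)(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])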
@@ -124,30 +122,28 @@ def forward(self, x, mask, freqs_cis):
 
 
 class LLAMA(nn.Module):
-    def __init__(self, model_args: Config, *args, **kwargs) -> None:
+    def __init__(self, cfg: Config, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
-        self.config = model_args
+        self.config = cfg
 
-        self.token_emb = nn.Embedding(model_args.vocab_size, model_args.d_model)
+        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
 
-        self.layers = nn.ModuleList([Block(model_args) for _ in range(model_args.num_layers)])
+        self.layers = nn.ModuleList([Block(cfg) for _ in range(cfg.num_layers)])
 
-        self.norm = RMSNorm(model_args.d_model)
-        self.vocab_proj = nn.Linear(model_args.d_model, model_args.vocab_size, bias=False)
+        self.norm = RMSNorm(cfg.d_model)
+        self.vocab_proj = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
 
-        if model_args.weight_tying:
+        if cfg.weight_tying:
             self.token_emb.weight = self.vocab_proj.weight
 
-        cos,isin = precompute_freqs_cis(
-            model_args.d_model // model_args.num_heads, model_args.seq_len * 2
-        )
-        self.register_buffer("freq_cos",cos)
-        self.register_buffer("freq_sin",isin)
+        cos, isin = precompute_freqs_cis(cfg.d_model // cfg.num_heads, cfg.seq_len * 2)
+        self.register_buffer("freq_cos", cos)
+        self.register_buffer("freq_sin", isin)
 
         if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
             print("WARNING: using slow attention | upgrade pytorch to 2.0 or above")
-            mask = torch.full((1, 1, model_args.seq_len, model_args.seq_len), float("-inf"))
+            mask = torch.full((1, 1, cfg.seq_len, cfg.seq_len), float("-inf"))
             mask = torch.triu(mask, diagonal=1)
             self.register_buffer("mask", mask)
         else:
@@ -160,7 +156,7 @@ def forward(self, x: torch.Tensor):
         x = self.token_emb(x)
         device = self.token_emb.weight.device
         freqs_cis = self.freq_cos[:seqlen], self.freq_sin[:seqlen]
 
         for layer in self.layers:
             x = layer(x, self.mask, freqs_cis)
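The weight_tying branch shares a single tensor between the input embedding and the output projection, which works because nn.Embedding(vocab, d) and nn.Linear(d, vocab) both store a (vocab, d) weight. A small sketch of the mechanism (vocabulary size and width are made up):

    import torch.nn as nn

    emb = nn.Embedding(32000, 512)
    proj = nn.Linear(512, 32000, bias=False)
    emb.weight = proj.weight          # the same assignment LLAMA.__init__ makes
    assert emb.weight is proj.weight  # one tensor, roughly 16M fewer parameters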
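precompute_freqs_cis is defined elsewhere in ohara; from this call site it returns a (cos, sin) pair covering seq_len * 2 positions, registered as buffers so the tables follow the model through .to(device) and show up in state_dict(). A hypothetical stand-in with the usual rotary-embedding math, its signature assumed from the call above:

    import torch

    def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(end).float()
        angles = torch.outer(t, freqs)  # (end, dim // 2)
        return torch.cos(angles), torch.sin(angles)

    cos, sin = precompute_freqs_cis(64, 2048)  # e.g. head_dim=64, seq_len * 2 = 2048
    print(cos.shape)  # torch.Size([2048, 32])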
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ select = [
     "PERF102",
     "UP006", "UP007",
     "FURB148", "FURB163", "FURB181",
-    "ASYNC100", "ASYNC102",
+    "ASYNC100",
     "TID251",
 ]
 ignore = [
