Support for Microsoft's Phi-2 model (#2548)
* phi-2 support
vince62s authored Jan 11, 2024
1 parent a987e7d commit 8045a86
Showing 6 changed files with 1,001 additions and 10 deletions.
18 changes: 18 additions & 0 deletions onmt/decoders/transformer.py
@@ -41,6 +41,7 @@ def __init__(
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
rotary_dim=0,
num_experts=0,
num_experts_per_tok=2,
):
@@ -87,6 +88,9 @@ def __init__(
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
rotary_theta (int): rotary base theta
rotary_dim (int): rotary dimension when lower than dim per head (0 uses dim per head)
num_experts (int): Number of experts for MoE
num_experts_per_tok (int): Number of experts chosen per token
"""
super(TransformerDecoderLayerBase, self).__init__()

@@ -99,6 +103,7 @@ def __init__(
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
rotary_dim=rotary_dim,
attn_type="self",
self_attn_type=self_attn_type,
add_qkvbias=add_qkvbias,
@@ -280,6 +285,7 @@ def __init__(
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
rotary_dim=0,
num_experts=0,
num_experts_per_tok=2,
):
@@ -312,6 +318,7 @@ def __init__(
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
rotary_dim=rotary_dim,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
@@ -475,6 +482,7 @@ def from_opt(cls, opt, embeddings):
sliding_window=opt.sliding_window,
rotary_interleave=opt.rotary_interleave,
rotary_theta=opt.rotary_theta,
rotary_dim=opt.rotary_dim,
num_experts=opt.num_experts,
num_experts_per_tok=opt.num_experts_per_tok,
)
@@ -566,6 +574,9 @@ class TransformerDecoder(TransformerDecoderBase):
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
rotary_theta (int): rotary base theta
rotary_dim (int): rotary dimension when lower than dim per head (0 uses dim per head)
num_experts (int): Number of experts for MoE
num_experts_per_tok (int): Number of experts chosen per token
"""

def __init__(
@@ -598,6 +609,7 @@ def __init__(
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
rotary_dim=0,
num_experts=0,
num_experts_per_tok=2,
):
@@ -632,6 +644,7 @@ def __init__(
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
rotary_dim=rotary_dim,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
@@ -842,6 +855,9 @@ class TransformerLMDecoder(TransformerDecoderBase):
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
rotary_theta (int): rotary base theta
rotary_dim (int): rotary dimension when lower than dim per head (0 uses dim per head)
num_experts (int): Number of experts for MoE
num_experts_per_tok (int): Number of experts chosen per token
"""

def __init__(
@@ -874,6 +890,7 @@ def __init__(
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
rotary_dim=0,
num_experts=0,
num_experts_per_tok=2,
):
@@ -907,6 +924,7 @@ def __init__(
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
rotary_dim=rotary_dim,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
6 changes: 6 additions & 0 deletions onmt/encoders/transformer.py
@@ -41,6 +41,7 @@ class TransformerEncoderLayer(nn.Module):
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
rotary_theta (int): rotary base theta
rotary_dim (int): rotary dimension when different from dim per head (0 uses dim per head)
"""

def __init__(
@@ -63,6 +64,7 @@ def __init__(
parallel_gpu=1,
rotary_interleave=True,
rotary_theta=1e4,
rotary_dim=0,
):
super(TransformerEncoderLayer, self).__init__()

@@ -75,6 +77,7 @@ def __init__(
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
rotary_dim=rotary_dim,
attn_type="self",
add_qkvbias=add_qkvbias,
num_kv=num_kv,
@@ -181,6 +184,7 @@ def __init__(
parallel_gpu=1,
rotary_interleave=True,
rotary_theta=1e4,
rotary_dim=0,
):
super(TransformerEncoder, self).__init__()

@@ -206,6 +210,7 @@ def __init__(
parallel_gpu=parallel_gpu,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
rotary_dim=rotary_dim,
)
for i in range(num_layers)
]
@@ -245,6 +250,7 @@ def from_opt(cls, opt, embeddings):
else 1,
rotary_interleave=opt.rotary_interleave,
rotary_theta=opt.rotary_theta,
rotary_dim=opt.rotary_dim,
)

def forward(self, src, src_len=None):
31 changes: 24 additions & 7 deletions onmt/modules/multi_headed_attn.py
@@ -51,8 +51,20 @@ def apply_rotary_emb(query, key, rope, interleave):
).type_as(key)
else:
cos, sin = rope.real, rope.imag
q_embed = (query * cos) + (rotate_half(query) * sin)
k_embed = (key * cos) + (rotate_half(key) * sin)
rotary_dim = cos.size(1)
head_dim = query.size(3)
if rotary_dim < head_dim:
q_embed = (query[:, :, :, :rotary_dim] * cos) + (
rotate_half(query[:, :, :, :rotary_dim]) * sin
)
k_embed = (key[:, :, :, :rotary_dim] * cos) + (
rotate_half(key[:, :, :, :rotary_dim]) * sin
)
q_embed = torch.cat([q_embed, query[:, :, :, rotary_dim:]], dim=-1)
k_embed = torch.cat([k_embed, key[:, :, :, rotary_dim:]], dim=-1)
else:
q_embed = (query * cos) + (rotate_half(query) * sin)
k_embed = (key * cos) + (rotate_half(key) * sin)
return q_embed.type_as(query), k_embed.type_as(key)
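Note: the hunk above implements partial rotary embeddings — only the first rotary_dim channels of each head are rotated and the remaining channels pass through unchanged, which is what Phi-2-style models expect. A minimal standalone sketch of the same idea, using the non-interleaved rotate_half convention; the helper names and the head_dim=80 / rotary_dim=32 shapes below are illustrative assumptions, not the onmt API:

import torch

def rotate_half(x):
    # swap the two halves of the last dim with a sign flip
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def partial_rope(x, cos, sin, rotary_dim):
    # x: (batch, heads, seq, head_dim); cos/sin broadcast over (seq, rotary_dim)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_rot = (x_rot * cos) + (rotate_half(x_rot) * sin)
    return torch.cat([x_rot, x_pass], dim=-1)

# toy usage; head_dim=80 / rotary_dim=32 are assumed Phi-2-like values
q = torch.randn(1, 2, 4, 80)
pos = torch.arange(4, dtype=torch.float32)
inv_freq = 1.0 / (1e4 ** (torch.arange(0, 32, 2).float() / 32))
angles = torch.outer(pos, inv_freq)                    # (seq, rotary_dim // 2)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)  # (seq, rotary_dim)
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
q_rot = partial_rope(q, cos, sin, 32)                  # shape unchanged: (1, 2, 4, 80)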


@@ -258,6 +270,7 @@ def __init__(
relative_positions_buckets: int = 0,
rotary_interleave: bool = True,
rotary_theta: int = 1e4,
rotary_dim: int = 0,
attn_type: str = None,
self_attn_type: str = None,
add_qkvbias=False,
@@ -352,7 +365,11 @@ def __init__(
self.relative_attention_bias = None

if max_relative_positions == -1: # rotary embeddings
self.rope = rotaryembeddings(self.dim_per_head, base=rotary_theta)
if rotary_dim == 0:
self.rotary_dim = self.dim_per_head
else:
self.rotary_dim = rotary_dim
self.rope = rotaryembeddings(self.rotary_dim, base=rotary_theta)
self.cos = (
self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
)
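Note: when rotary_dim is left at 0 the rotation covers the full head dimension; otherwise the rope/cos/sin caches are built only for the first rotary_dim channels. A simplified stand-in for that set-up — not the onmt rotaryembeddings helper; the complex-table shape is an assumption based on the .real / .imag usage in apply_rotary_emb:

import torch

def build_rope_cache(dim, maxseqlen=2048, base=1e4):
    # one complex rotation per (position, channel pair), duplicated up to dim
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(maxseqlen, dtype=torch.float32), inv_freq)
    rope = torch.polar(torch.ones_like(angles), angles)  # complex, (maxseqlen, dim // 2)
    return torch.cat([rope, rope], dim=1)                # (maxseqlen, dim)

dim_per_head, rotary_dim_opt = 80, 32                    # illustrative Phi-2-like values
rotary_dim = dim_per_head if rotary_dim_opt == 0 else rotary_dim_opt
rope = build_rope_cache(rotary_dim)
cos = rope[:, : rope.size(1) // 2].real.contiguous()     # mirrors the cached cos above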
@@ -431,10 +448,10 @@ def forward(
self.linear_keys(query),
self.linear_values(query),
)

query = shape(query, self.dim_per_head)
key = shape(key, self.dim_per_head)
value = shape(value, self.dim_per_head)

start_pos = step
seqlen = query.size(2)

@@ -449,7 +466,7 @@ def forward(
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head,
self.rotary_dim,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
@@ -472,7 +489,7 @@ def forward(
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head,
self.rotary_dim,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
@@ -577,7 +594,7 @@ def forward(
seqlen = query.size(2)
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head,
self.rotary_dim,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
7 changes: 7 additions & 0 deletions onmt/opts.py
@@ -887,6 +887,13 @@ def model_opts(parser):
default=10000,
help="Rotary theta base length" "1e4 for Llama2.Mistral" "1e6 for Mixtral",
)
group.add(
"--rotary_dim",
"-rotary_dim",
type=int,
default=0,
help="Rotary dim when model requires it to be different to head dim",
)
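Note: a hypothetical standalone illustration of the flag's semantics (plain argparse, not onmt's option parser) — the default 0 keeps rotary embeddings over the full head dimension, while a positive value such as 32 rotates only that many channels per head:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--rotary_dim", "-rotary_dim", type=int, default=0,
                    help="Rotary dim when the model requires it to differ from head dim")
opt = parser.parse_args(["-rotary_dim", "32"])   # assumed Phi-2-like setting
dim_per_head = 80                                # illustrative head dim
rotary_dim = dim_per_head if opt.rotary_dim == 0 else opt.rotary_dim
print(rotary_dim)                                # -> 32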
group.add(
"--heads",
"-heads",
6 changes: 3 additions & 3 deletions onmt/translate/translator.py
@@ -1115,7 +1115,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
)

# (4) Begin decoding step by step:
beg_time = time()
# beg_time = time()
for step in range(decode_strategy.max_length):
decoder_input = (
src if step == 0 else decode_strategy.current_predictions.view(-1, 1, 1)
@@ -1153,8 +1153,8 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
if parallel_paths > 1 or any_finished:
# select indexes in model state/cache
self.model.decoder.map_state(lambda state, dim: state[select_indices])
if step == 0:
print("step0 time: ", time() - beg_time)
# if step == 0:
# print("step0 time: ", time() - beg_time)

return self.report_results(
gold_score,