@@ -0,0 +1,85 @@
absl-py==0.3.0
astor==0.7.1
autopep8==1.3.5
backcall==0.1.0
bleach==2.1.4
certifi==2018.8.24
chardet==3.0.4
colorama==0.3.9
cycler==0.10.0
decorator==4.3.0
defusedxml==0.5.0
entrypoints==0.2.3
gast==0.2.0
grpcio==1.14.1
html5lib==1.0.1
idna==2.7
ipykernel==5.0.0
ipython==7.0.1
ipython-genutils==0.2.0
ipywidgets==7.4.2
isort==4.3.4
jedi==0.12.1
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.3
jupyter-console==5.2.0
jupyter-core==4.4.0
kiwisolver==1.0.1
lxml==4.2.5
Markdown==2.6.11
MarkupSafe==1.0
matplotlib==2.2.2
mccabe==0.6.1
mistune==0.8.3
nbconvert==5.4.0
nbformat==4.4.0
nltk==3.3
notebook==5.7.0
numpy==1.14.5
opencv-python==3.4.2.17
pandas==0.23.4
pandas-datareader==0.7.0
pandocfilters==1.4.2
parso==0.3.1
pickleshare==0.7.5
Pillow==5.2.0
prometheus-client==0.3.1
prompt-toolkit==1.0.15
protobuf==3.6.0
pycodestyle==2.4.0
Pygments==2.2.0
pyparsing==2.2.0
python-dateutil==2.7.3
pytz==2018.5
pywinpty==0.5.4
pyzmq==17.1.2
qtconsole==4.4.1
requests==2.19.1
scikit-learn==0.19.2
scipy==1.1.0
Send2Trash==1.5.0
simplegeneric==0.8.1
six==1.11.0
tensorboard==1.10.0
tensorboardX==1.4
tensorflow==1.10.0
termcolor==1.1.0
terminado==0.8.1
testpath==0.4.1
torch==0.4.1
torchfile==0.1.0
torchnet==0.0.4
torchvision==0.2.1
tornado==5.1.1
traitlets==4.3.2
urllib3==1.23
visdom==0.1.8.5
wcwidth==0.1.7
webencodings==0.5.1
websocket-client==0.53.0
Werkzeug==0.14.1
widgetsnbextension==3.4.2
wrapt==1.10.11
xgboost==0.80
@@ -0,0 +1,151 @@
import math
import torch
from torch import nn
from typing import Optional, Tuple
from transformers.models.llama import LlamaConfig, LlamaForCausalLM


class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # \theta_i = 10000 ^ {-2 i / d}, shape (head_dim // 2, )
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):

        # m \theta_i, shape (sequence_length, head_dim // 2)
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Different from the paper, but it uses a different permutation in order to obtain the same calculation
        # m \theta_0, m \theta_1, \cdots, m \theta_{d/2-1} | m \theta_0, m \theta_1, \cdots, m \theta_{d/2-1}
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

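
# Worked example of the cache shapes above (illustrative numbers, not from the original
# file): with head_dim = 8 and max_position_embeddings = 4, `inv_freq` has shape
# (head_dim // 2,) = (4,), `freqs = torch.einsum("i,j->ij", t, inv_freq)` has shape (4, 4),
# and `cos_cached` / `sin_cached` each have shape (max_position_embeddings, head_dim) = (4, 8)
# after the concatenation.
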
class LlamaAttention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        # num_heads * head_dim == hidden_size
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        self.rotary_emb = LlamaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        # (batch_size, sequence_length, hidden_size)
        bsz, q_len, _ = hidden_states.size()

        # (batch_size, sequence_length, num_heads * head_dim)
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch_size, num_heads, sequence_length, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            # - x_{d/2}, \cdots, - x_{d-1} | x_0, \cdots, x_{d/2-1}
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
            # (sequence_length, head_dim) -> (batch_size, 1, sequence_length, head_dim)
            cos = cos[position_ids].unsqueeze(1)
            sin = sin[position_ids].unsqueeze(1)

            # x_i and x_{i + d/2} are rotated together as one pair
            # (batch_size, num_heads, sequence_length, head_dim)
            q_embed = (q * cos) + (rotate_half(q) * sin)
            k_embed = (k * cos) + (rotate_half(k) * sin)

            return q_embed, k_embed

        # length of the key/value sequence (no KV cache in this simplified version)
        kv_seq_len = key_states.shape[-2]
        """
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        """
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        """
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None
        """

        # (batch_size, num_heads, sequence_length, head_dim)
        # x (batch_size, num_heads, head_dim, kv_sequence_length)
        # -> (batch_size, num_heads, sequence_length, kv_sequence_length)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  # upcast attention to fp32

        # (batch_size, num_heads, sequence_length, kv_sequence_length)
        # x (batch_size, num_heads, kv_sequence_length, head_dim)
        # -> (batch_size, num_heads, sequence_length, head_dim)
        attn_output = torch.matmul(attn_weights, value_states)

        # (batch_size, sequence_length, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()

        # (batch_size, sequence_length, hidden_size)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        # (batch_size, sequence_length, hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights, past_key_value
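
# Quick shape check for the module above. A minimal sketch, not part of the original file:
# the tiny config values below are illustrative and assume the installed `transformers`
# version provides `rope_theta` and `attention_bias` defaults on `LlamaConfig`
# (the attention code above already relies on both fields).
if __name__ == "__main__":
    config = LlamaConfig(hidden_size=64, num_attention_heads=4, max_position_embeddings=128)
    attn = LlamaAttention(config)

    bsz, seq_len = 2, 16
    hidden_states = torch.randn(bsz, seq_len, config.hidden_size)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)

    out, weights, _ = attn(hidden_states, position_ids=position_ids)
    print(out.shape)      # torch.Size([2, 16, 64])
    print(weights.shape)  # torch.Size([2, 4, 16, 16])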
@@ -0,0 +1 @@
# TODO: todel
@@ -0,0 +1,119 @@
import os
import torch
from torch import nn


class LlamaRotaryEmbedding(torch.nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, type="standard", scaling_factor=1.0):
        super().__init__()

        if type == "ntk-scaling":
            base = base * scaling_factor ** (dim / (dim - 2))  # NTK-aware scaling enlarges the base

        # shape (dim // 2, ), θ_i, i = 0, \cdots, d_k / 2 - 1
        # θ_0, θ_1, ..., θ_{d_k / 2 - 1}
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        # shape (max_position_embeddings, ), positions
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)

        if type == "linear-interpolation":
            t = t / scaling_factor  # linear interpolation compresses the positions

        # shape (max_position_embeddings, dim // 2)
        # 0 * θ_0, 0 * θ_1, ..., 0 * θ_{d_k / 2 - 1}
        # 1 * θ_0, 1 * θ_1, ..., 1 * θ_{d_k / 2 - 1}
        # ...
        # t * θ_0, t * θ_1, ..., t * θ_{d_k / 2 - 1}
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from the paper, but it uses a different permutation in order to obtain the same calculation
        # shape (max_position_embeddings, dim)
        # 0 * θ_0, 0 * θ_1, ..., 0 * θ_{d_k / 2 - 1} | 0 * θ_0, 0 * θ_1, ..., 0 * θ_{d_k / 2 - 1}
        # 1 * θ_0, 1 * θ_1, ..., 1 * θ_{d_k / 2 - 1} | 1 * θ_0, 1 * θ_1, ..., 1 * θ_{d_k / 2 - 1}
        # ... | ...
        # t * θ_0, t * θ_1, ..., t * θ_{d_k / 2 - 1} | t * θ_0, t * θ_1, ..., t * θ_{d_k / 2 - 1}
        emb = torch.cat((freqs, freqs), dim=-1)
        # shape (1, 1, max_position_embeddings, dim)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        # shape (1, 1, max_position_embeddings, dim)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from the paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        # shape (1, 1, sequence_length, dim)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

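
# Summary of the three variants above (m = position index, θ_i = base^(-2i / d)):
#   standard:             angle_i(m) = m · θ_i
#   linear-interpolation: angle_i(m) = (m / s) · θ_i, i.e. every position is compressed by s
#   ntk-scaling:          angle_i(m) = m · θ'_i with θ'_i = (base · s^(d / (d - 2)))^(-2i / d),
#                         which leaves the highest frequency (i = 0) unchanged and stretches
#                         the lowest frequency (i = d/2 - 1) by exactly a factor of s.
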
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    scaling_factor = 4.0
    types = ["standard", "linear-interpolation", "ntk-scaling"]
    fig, axs = plt.subplots(3, 1, figsize=(10, 10))

    for i in range(len(types)):
        type = types[i]

        dim = 768 * 2  # y
        max_position_embeddings = 512 * int(scaling_factor)  # x

        if type == "standard":
            rope = LlamaRotaryEmbedding(
                dim=dim, max_position_embeddings=max_position_embeddings, type=type, scaling_factor=1.0,
            )

        elif type == "linear-interpolation":
            rope = LlamaRotaryEmbedding(
                dim=dim, max_position_embeddings=max_position_embeddings, type=type, scaling_factor=scaling_factor,
            )

        elif type == "ntk-scaling":
            rope = LlamaRotaryEmbedding(
                dim=dim, max_position_embeddings=max_position_embeddings, type=type, scaling_factor=scaling_factor,
            )

        xticks = [i for i in range(0, max_position_embeddings + 1, 256)]
        yticks = [i for i in range(0, dim // 2 + 1, 128)]

        # twin_axs = axs[i].twinx()

        axs[i].set_title([
            "Sinusoidal Position Embedding (Standard)",
            "Sinusoidal Position Embedding (Linear Interpolation)",
            "Sinusoidal Position Embedding (NTK-Scaling)",
        ][i])

        cos_im = torch.flip(rope.cos_cached[0][0][..., :dim // 2].T, dims=[0])  # flip vertically
        if i == 0:
            cos_im[:, int(max_position_embeddings / scaling_factor):] = 0.0
        axs[i].imshow(cos_im, cmap="coolwarm")

        axs[i].set_xticks(ticks=xticks, labels=xticks)
        axs[i].set_yticks(ticks=yticks, labels=yticks[::-1])
        # twin_axs.set_yticks(ticks=yticks, labels=yticks[::-1])

        if i == 2:
            axs[i].set_xlabel("position(x)")

        axs[i].set_ylabel("dimension(i)")
        # twin_axs.set_ylabel("frequency(i)")

        if i > 0:
            axs[i].axvline(x=int(max_position_embeddings / scaling_factor), color='r', linestyle='--')

    plt.axis('on')  # optionally toggle the axes on or off
    plt.show()
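
    # Illustrative numeric check (an appended sketch, reusing this script's `dim` and
    # `scaling_factor`): NTK scaling enlarges the RoPE base via the same formula used in
    # __init__ above, while linear interpolation leaves the base alone and divides every
    # position index by `scaling_factor` instead.
    base = 10000.0
    ntk_base = base * scaling_factor ** (dim / (dim - 2))
    print(f"base: {base:.0f} -> NTK-scaled base: {ntk_base:.1f} (scaling_factor={scaling_factor})")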