-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathsoftcapping.py
104 lines (70 loc) · 2.74 KB
/
softcapping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Implementation of tanh softcapping score mod popularized in Gemma2 and Grok-1"""
import torch
from torch import Tensor
from torch.nn.attention.flex_attention import _score_mod_signature
from torch._inductor.lowering import make_pointwise, register_lowering
# Some internal torch.compile details
from torch._inductor.virtualized import ops
from functools import partial
@torch.library.custom_op("approx::tanh", mutates_args=())
def _tanh_approx(inp: Tensor) -> Tensor:
    """Custom op wrapping tanh so inductor can lower it to a fast PTX approximation."""
    return inp.tanh()
@_tanh_approx.register_fake
def _(inp: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation: mirrors the real op's shape/dtype semantics
    # so torch.compile can trace through without running the kernel.
    return inp.tanh()
def _tanh_approx_lowering(inp):
    """Inductor lowering: emit the `tanh.approx.f32` PTX instruction elementwise."""
    asm_fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
    pointwise_fn = make_pointwise(asm_fn)
    return pointwise_fn(inp)


register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering)
class _TanhApprox(torch.autograd.Function):
    """Autograd wrapper for the approx-tanh custom op.

    Forward dispatches to ``torch.ops.approx.tanh``; backward reuses the saved
    output via the identity d/dx tanh(x) = 1 - tanh(x)^2.
    """

    @staticmethod
    def forward(x):
        return torch.ops.approx.tanh(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Save the *output* (tanh(x)), not the input — backward only needs it.
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        return grad_output * (1 - y * y)

    @staticmethod
    def vmap(info, in_dims, x):
        # Under vmap, fall back to exact tanh; batch dim of the output is 0.
        return torch.tanh(x), 0


# Rebind the public helper name to the differentiable version.
_tanh_approx = _TanhApprox.apply
def generate_tanh_softcap(soft_cap: int, approx: bool = False) -> _score_mod_signature:
"""Returns an tanh bias score_mod given the number of heads H
Args:
soft_cap: The soft cap value to use for normalizing logits
approx: Whether to use the `tanh.approx.` ptx instruction
Returns:
tanh_softcap: score_mod
"""
tanh = _tanh_approx if approx else torch.tanh
def tanh_softcap(score, b, h, q_idx, kv_idx):
return soft_cap * tanh(score / soft_cap)
prefix = "tanh_softcap_approx" if approx else "tanh_softcap"
tanh_softcap.__name__ = f"{prefix}_{soft_cap}"
return tanh_softcap
def main(device: str = "cpu"):
    """Visualize the attention scores produced by the tanh_softcap score mod.

    Args:
        device (str): Device to use for computation. Defaults to "cpu".
    """
    import torch

    from attn_gym import visualize_attention_scores

    B, H, SEQ_LEN, HEAD_DIM = 1, 1, 12, 8

    def rand_qk():
        return torch.rand(B, H, SEQ_LEN, HEAD_DIM, device=device)

    query = rand_qk()
    key = rand_qk()

    tanh_softcap_score_mod = generate_tanh_softcap(30, approx=True)

    visualize_attention_scores(
        query, key, score_mod=tanh_softcap_score_mod, device=device, name="tanh_softcap_score_mod"
    )
if __name__ == "__main__":
    # jsonargparse is an optional visualization dependency; fail with a
    # helpful install hint rather than a bare ModuleNotFoundError.
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")

    CLI(main)