Skip to content

Commit 21f1235

Browse files
committed
clean up doc string
1 parent 757b03d commit 21f1235

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

attn_gym/mods/softcapping.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Implementation of an tanh softcapping score mod popularized in Gemma2 paper."""
1+
"""Implementation of tanh softcapping score mod popularized in Gemma2 and Grok-1"""
22

33
import torch
44
from torch import Tensor
@@ -11,21 +11,21 @@
1111

1212

1313
@torch.library.custom_op("approx::tanh", mutates_args=())
14-
def tanh_approx(inp: Tensor) -> Tensor:
14+
def _tanh_approx(inp: Tensor) -> Tensor:
1515
return torch.tanh(inp)
1616

1717

18-
@tanh_approx.register_fake
18+
@_tanh_approx.register_fake
1919
def _(inp: torch.Tensor) -> torch.Tensor:
2020
return torch.tanh(inp)
2121

2222

23-
def tanh_approx_lowering(inp):
23+
def _tanh_approx_lowering(inp):
2424
fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
2525
return make_pointwise(fn)(inp)
2626

2727

28-
register_lowering(torch.ops.approx.tanh)(tanh_approx_lowering)
28+
register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering)
2929

3030

3131
class _TanhApprox(torch.autograd.Function):

0 commit comments

Comments
 (0)