File tree 1 file changed +5
-5
lines changed
1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change 1
- """Implementation of an tanh softcapping score mod popularized in Gemma2 paper. """
1
+ """Implementation of tanh softcapping score mod popularized in Gemma2 and Grok-1 """
2
2
3
3
import torch
4
4
from torch import Tensor
11
11
12
12
13
13
@torch .library .custom_op ("approx::tanh" , mutates_args = ())
14
- def tanh_approx (inp : Tensor ) -> Tensor :
14
+ def _tanh_approx (inp : Tensor ) -> Tensor :
15
15
return torch .tanh (inp )
16
16
17
17
18
- @tanh_approx .register_fake
18
+ @_tanh_approx .register_fake
19
19
def _ (inp : torch .Tensor ) -> torch .Tensor :
20
20
return torch .tanh (inp )
21
21
22
22
23
- def tanh_approx_lowering (inp ):
23
+ def _tanh_approx_lowering (inp ):
24
24
fn = partial (ops .inline_asm_elementwise , asm = "tanh.approx.f32 $0, $1;" )
25
25
return make_pointwise (fn )(inp )
26
26
27
27
28
- register_lowering (torch .ops .approx .tanh )(tanh_approx_lowering )
28
+ register_lowering (torch .ops .approx .tanh )(_tanh_approx_lowering )
29
29
30
30
31
31
class _TanhApprox (torch .autograd .Function ):
You can’t perform that action at this time.
0 commit comments