Skip to content

Commit 8fa7b4c

Browse files
committed
help out @cutoken at #159
1 parent c7cc222 commit 8fa7b4c

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

setup.py

+1 -1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'x-transformers',
55
packages = find_packages(exclude=['examples']),
6-
version = '1.16.9',
6+
version = '1.16.10',
77
license='MIT',
88
description = 'X-Transformers - Pytorch',
99
author = 'Phil Wang',

x_transformers/x_transformers.py

+9 -2
Original file line numberDiff line numberDiff line change
@@ -398,12 +398,16 @@ def __init__(
398398
self,
399399
dim,
400400
use_xpos = False,
401-
scale_base = 512
401+
scale_base = 512,
402+
interpolation_factor = 1.
402403
):
403404
super().__init__()
404405
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
405406
self.register_buffer('inv_freq', inv_freq)
406407

408+
assert interpolation_factor >= 1.
409+
self.interpolation_factor = interpolation_factor
410+
407411
if not use_xpos:
408412
self.register_buffer('scale', None)
409413
return
@@ -415,6 +419,8 @@ def __init__(
415419

416420
def forward(self, seq_len, device):
417421
t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
422+
t = t / self.interpolation_factor
423+
418424
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
419425
freqs = torch.cat((freqs, freqs), dim = -1)
420426

@@ -887,6 +893,7 @@ def __init__(
887893
rotary_emb_dim = None,
888894
rotary_xpos = False,
889895
rotary_xpos_scale_base = 512,
896+
rotary_interpolation_factor = 1.,
890897
custom_layers = None,
891898
sandwich_coef = None,
892899
par_ratio = None,
@@ -925,7 +932,7 @@ def __init__(
925932
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
926933

927934
assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
928-
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base) if rotary_pos_emb else None
935+
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor) if rotary_pos_emb else None
929936

930937
assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
931938
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

0 commit comments

Comments
 (0)