Skip to content

Commit a957b68

Browse files
committed
Revert "help out @cutoken at #159"
This reverts commit 8fa7b4c.
1 parent baa2d6e commit a957b68

File tree

2 files changed

+3
-10
lines changed

2 files changed

+3
-10
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'x-transformers',
55
packages = find_packages(exclude=['examples']),
6-
version = '1.16.10',
6+
version = '1.16.9',
77
license='MIT',
88
description = 'X-Transformers - Pytorch',
99
author = 'Phil Wang',

x_transformers/x_transformers.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -398,16 +398,12 @@ def __init__(
398398
self,
399399
dim,
400400
use_xpos = False,
401-
scale_base = 512,
402-
interpolation_factor = 1.
401+
scale_base = 512
403402
):
404403
super().__init__()
405404
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
406405
self.register_buffer('inv_freq', inv_freq)
407406

408-
assert interpolation_factor >= 1.
409-
self.interpolation_factor = interpolation_factor
410-
411407
if not use_xpos:
412408
self.register_buffer('scale', None)
413409
return
@@ -419,8 +415,6 @@ def __init__(
419415

420416
def forward(self, seq_len, device):
421417
t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
422-
t = t / self.interpolation_factor
423-
424418
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
425419
freqs = torch.cat((freqs, freqs), dim = -1)
426420

@@ -893,7 +887,6 @@ def __init__(
893887
rotary_emb_dim = None,
894888
rotary_xpos = False,
895889
rotary_xpos_scale_base = 512,
896-
rotary_interpolation_factor = 1.,
897890
custom_layers = None,
898891
sandwich_coef = None,
899892
par_ratio = None,
@@ -932,7 +925,7 @@ def __init__(
932925
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
933926

934927
assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
935-
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor) if rotary_pos_emb else None
928+
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base) if rotary_pos_emb else None
936929

937930
assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
938931
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

0 commit comments

Comments
 (0)