@@ -398,16 +398,12 @@ def __init__(
398
398
self ,
399
399
dim ,
400
400
use_xpos = False ,
401
- scale_base = 512 ,
402
- interpolation_factor = 1.
401
+ scale_base = 512
403
402
):
404
403
super ().__init__ ()
405
404
inv_freq = 1. / (10000 ** (torch .arange (0 , dim , 2 ).float () / dim ))
406
405
self .register_buffer ('inv_freq' , inv_freq )
407
406
408
- assert interpolation_factor >= 1.
409
- self .interpolation_factor = interpolation_factor
410
-
411
407
if not use_xpos :
412
408
self .register_buffer ('scale' , None )
413
409
return
@@ -419,8 +415,6 @@ def __init__(
419
415
420
416
def forward (self , seq_len , device ):
421
417
t = torch .arange (seq_len , device = device ).type_as (self .inv_freq )
422
- t = t / self .interpolation_factor
423
-
424
418
freqs = torch .einsum ('i , j -> i j' , t , self .inv_freq )
425
419
freqs = torch .cat ((freqs , freqs ), dim = - 1 )
426
420
@@ -893,7 +887,6 @@ def __init__(
893
887
rotary_emb_dim = None ,
894
888
rotary_xpos = False ,
895
889
rotary_xpos_scale_base = 512 ,
896
- rotary_interpolation_factor = 1. ,
897
890
custom_layers = None ,
898
891
sandwich_coef = None ,
899
892
par_ratio = None ,
@@ -932,7 +925,7 @@ def __init__(
932
925
rotary_emb_dim = max (default (rotary_emb_dim , dim_head // 2 ), 32 )
933
926
934
927
assert not (rotary_xpos and not causal ), 'rotary xpos is not compatible with bidirectional attention'
935
- self .rotary_pos_emb = RotaryEmbedding (rotary_emb_dim , use_xpos = rotary_xpos , scale_base = rotary_xpos_scale_base , interpolation_factor = rotary_interpolation_factor ) if rotary_pos_emb else None
928
+ self .rotary_pos_emb = RotaryEmbedding (rotary_emb_dim , use_xpos = rotary_xpos , scale_base = rotary_xpos_scale_base ) if rotary_pos_emb else None
936
929
937
930
assert not (alibi_pos_bias and rel_pos_bias ), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
938
931
assert rel_pos_num_buckets <= rel_pos_max_distance , 'number of relative position buckets must be less than the relative position max distance'
0 commit comments