@@ -398,12 +398,16 @@ def __init__(
398
398
self ,
399
399
dim ,
400
400
use_xpos = False ,
401
- scale_base = 512
401
+ scale_base = 512 ,
402
+ interpolation_factor = 1.
402
403
):
403
404
super ().__init__ ()
404
405
inv_freq = 1. / (10000 ** (torch .arange (0 , dim , 2 ).float () / dim ))
405
406
self .register_buffer ('inv_freq' , inv_freq )
406
407
408
+ assert interpolation_factor >= 1.
409
+ self .interpolation_factor = interpolation_factor
410
+
407
411
if not use_xpos :
408
412
self .register_buffer ('scale' , None )
409
413
return
@@ -415,6 +419,8 @@ def __init__(
415
419
416
420
def forward (self , seq_len , device ):
417
421
t = torch .arange (seq_len , device = device ).type_as (self .inv_freq )
422
+ t = t / self .interpolation_factor
423
+
418
424
freqs = torch .einsum ('i , j -> i j' , t , self .inv_freq )
419
425
freqs = torch .cat ((freqs , freqs ), dim = - 1 )
420
426
@@ -887,6 +893,7 @@ def __init__(
887
893
rotary_emb_dim = None ,
888
894
rotary_xpos = False ,
889
895
rotary_xpos_scale_base = 512 ,
896
+ rotary_interpolation_factor = 1. ,
890
897
custom_layers = None ,
891
898
sandwich_coef = None ,
892
899
par_ratio = None ,
@@ -925,7 +932,7 @@ def __init__(
925
932
rotary_emb_dim = max (default (rotary_emb_dim , dim_head // 2 ), 32 )
926
933
927
934
assert not (rotary_xpos and not causal ), 'rotary xpos is not compatible with bidirectional attention'
928
- self .rotary_pos_emb = RotaryEmbedding (rotary_emb_dim , use_xpos = rotary_xpos , scale_base = rotary_xpos_scale_base ) if rotary_pos_emb else None
935
+ self .rotary_pos_emb = RotaryEmbedding (rotary_emb_dim , use_xpos = rotary_xpos , scale_base = rotary_xpos_scale_base , interpolation_factor = rotary_interpolation_factor ) if rotary_pos_emb else None
929
936
930
937
assert not (alibi_pos_bias and rel_pos_bias ), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
931
938
assert rel_pos_num_buckets <= rel_pos_max_distance , 'number of relative position buckets must be less than the relative position max distance'
0 commit comments