@@ -302,7 +302,7 @@ def __init__(
302302 inv_freq = 1.0 / (
303303 base ** (torch .arange (0 , dim , 2 , device = "cuda" , dtype = torch .float32 ) / dim )
304304 )
305- self .register_buffer ("inv_freq" , inv_freq . to ( dtype ) , persistent = False )
305+ self .register_buffer ("inv_freq" , inv_freq , persistent = False )
306306
307307 self ._seq_len_cached = - 1
308308 self ._cos_cached = None
@@ -333,14 +333,14 @@ def _update_cos_sin_tables(self, x, seq_dim):
333333 freqs = torch .outer (t * self .distance_scale , self .inv_freq )
334334 emb = torch .cat ((freqs , freqs ), dim = - 1 )
335335 if x .dim () == 2 :
336- self ._cos_cached = emb .cos ()
337- self ._sin_cached = emb .sin ()
336+ self ._cos_cached = emb .cos (). to ( self . dtype )
337+ self ._sin_cached = emb .sin (). to ( self . dtype )
338338 elif x .dim () == 3 :
339- self ._cos_cached = emb .cos ()[None , :, :]
340- self ._sin_cached = emb .sin ()[None , :, :]
339+ self ._cos_cached = emb .cos ()[None , :, :]. to ( self . dtype )
340+ self ._sin_cached = emb .sin ()[None , :, :]. to ( self . dtype )
341341 elif x .dim () == 4 :
342- self ._cos_cached = emb .cos ()[None , None , :, :]
343- self ._sin_cached = emb .sin ()[None , None , :, :]
342+ self ._cos_cached = emb .cos ()[None , None , :, :]. to ( self . dtype )
343+ self ._sin_cached = emb .sin ()[None , None , :, :]. to ( self . dtype )
344344 return self ._cos_cached , self ._sin_cached
345345
346346 def forward (self , q : torch .Tensor , k : torch .Tensor , seq_dim = - 2 ) -> Tuple [torch .Tensor , torch .Tensor ]:
0 commit comments