Skip to content

Commit 1449045

Browse files
authored
Mixed Precision RoPE
1 parent d084767 commit 1449045

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

model_center/layer/position_embedding.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def __init__(
302302
inv_freq = 1.0 / (
303303
base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim)
304304
)
305-
self.register_buffer("inv_freq", inv_freq.to(dtype), persistent=False)
305+
self.register_buffer("inv_freq", inv_freq, persistent=False)
306306

307307
self._seq_len_cached = -1
308308
self._cos_cached = None
@@ -333,14 +333,14 @@ def _update_cos_sin_tables(self, x, seq_dim):
333333
freqs = torch.outer(t * self.distance_scale, self.inv_freq)
334334
emb = torch.cat((freqs, freqs), dim = -1)
335335
if x.dim() == 2:
336-
self._cos_cached = emb.cos()
337-
self._sin_cached = emb.sin()
336+
self._cos_cached = emb.cos().to(self.dtype)
337+
self._sin_cached = emb.sin().to(self.dtype)
338338
elif x.dim() == 3:
339-
self._cos_cached = emb.cos()[None, :, :]
340-
self._sin_cached = emb.sin()[None, :, :]
339+
self._cos_cached = emb.cos()[None, :, :].to(self.dtype)
340+
self._sin_cached = emb.sin()[None, :, :].to(self.dtype)
341341
elif x.dim() == 4:
342-
self._cos_cached = emb.cos()[None, None, :, :]
343-
self._sin_cached = emb.sin()[None, None, :, :]
342+
self._cos_cached = emb.cos()[None, None, :, :].to(self.dtype)
343+
self._sin_cached = emb.sin()[None, None, :, :].to(self.dtype)
344344
return self._cos_cached, self._sin_cached
345345

346346
def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dim= -2) -> Tuple[torch.Tensor, torch.Tensor]:

0 commit comments

Comments
 (0)