davda54 committed on
Commit
30fd0e7
·
verified ·
1 Parent(s): 39265fc
Files changed (1) hide show
  1. modeling_gptbert.py +0 -3
modeling_gptbert.py CHANGED
@@ -243,9 +243,6 @@ class UnpaddedRotaryEmbedding(RotaryEmbedding):
243
  super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=None, interleaved=False)
244
  self.max_seqlen = max_seqlen
245
 
246
- if max_seqlen is not None and device is not None and dtype is not None:
247
- self._update_cos_sin_cache(max_seqlen, device=device, dtype=None)
248
-
249
  def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
250
  if max_seqlen is not None:
251
  self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
 
243
  super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=None, interleaved=False)
244
  self.max_seqlen = max_seqlen
245
 
 
 
 
246
  def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
247
  if max_seqlen is not None:
248
  self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)