diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index e84c2c4a792e..4786cce27356 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -215,7 +215,15 @@ def _compute_dynamic_ntk_parameters(
     attention_factor = 1.0  # Unused in this type of RoPE
 
     # seq_len: default to max_position_embeddings, e.g. at init time
-    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
+    if seq_len is None:
+        seq_len = max_position_embeddings
+    elif isinstance(seq_len, torch.Tensor):
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+        )
+    else:
+        seq_len = max(seq_len, max_position_embeddings)
 
     # Compute the inverse frequencies
     base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index fd9f5887b6c4..761a785f369f 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -220,6 +220,9 @@ def test_dynamic_rope_numerically(self):
         inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
         torch.testing.assert_close(inv_freq, default_inv_freq)
 
+        inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=torch.tensor(1, dtype=torch.int64))
+        torch.testing.assert_close(inv_freq, default_inv_freq)
+
         # Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
         # will scale up (i.e., the inverse frequencies will scale down).
         factor = 10.0
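
For reviewers, here is a minimal standalone sketch of the branching the patch introduces. `clamp_seq_len` is a hypothetical helper name used only for illustration; the real change inlines this logic in `_compute_dynamic_ntk_parameters`, and the stated motivation for the tensor branch (keeping the max on-device rather than forcing a host-side comparison) is an assumption on my part:

```python
import torch


def clamp_seq_len(seq_len, max_position_embeddings: int):
    """Hypothetical helper mirroring the patched logic.

    Returns `seq_len` floored at `max_position_embeddings`, handling
    None, plain Python ints, and 0-dim torch tensors.
    """
    if seq_len is None:
        # e.g. at init time: fall back to the original training length
        return max_position_embeddings
    if isinstance(seq_len, torch.Tensor):
        # torch.maximum expects two tensors, so wrap the int in a tensor
        # matching seq_len's dtype and device; this keeps the comparison
        # on-device instead of pulling the tensor back to a Python scalar
        return torch.maximum(
            seq_len,
            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
        )
    # plain Python number: the built-in max suffices
    return max(seq_len, max_position_embeddings)


# Usage sketch: all three input kinds floor at the training length
assert clamp_seq_len(None, 4096) == 4096
assert clamp_seq_len(8192, 4096) == 8192
assert int(clamp_seq_len(torch.tensor(1, dtype=torch.int64), 4096)) == 4096
```

On the design choice: the old one-liner would have evaluated `seq_len > max_position_embeddings` on a tensor, which yields a 0-dim bool tensor whose truthiness the ternary must resolve eagerly; the explicit `torch.maximum` branch presumably avoids that data-dependent Python bool, which is what the new tensor-input test case exercises.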