From 2813aae72722e9b6899e29de86c6001369d893b0 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 2 Jul 2025 11:10:34 +0200 Subject: [PATCH 1/2] Make _compute_dynamic_ntk_parameters exportable --- src/transformers/modeling_rope_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index e84c2c4a792e..51ee6323fab3 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -215,7 +215,14 @@ def _compute_dynamic_ntk_parameters( attention_factor = 1.0 # Unused in this type of RoPE # seq_len: default to max_position_embeddings, e.g. at init time - seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings + if seq_len is None: + seq_len = max_position_embeddings + else: + torch._check(isinstance(seq_len, torch.Tensor)) + seq_len = torch.maximum( + seq_len, + torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device), + ) # Compute the inverse frequencies base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) From b5f788a2c7c4f9bee10f03e066cf469538dad7c8 Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 2 Jul 2025 12:14:36 +0200 Subject: [PATCH 2/2] add unit test --- src/transformers/modeling_rope_utils.py | 5 +++-- tests/utils/test_modeling_rope_utils.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index 51ee6323fab3..4786cce27356 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -217,12 +217,13 @@ def _compute_dynamic_ntk_parameters( # seq_len: default to max_position_embeddings, e.g. at init time if seq_len is None: seq_len = max_position_embeddings - else: - torch._check(isinstance(seq_len, torch.Tensor)) + elif isinstance(seq_len, torch.Tensor): seq_len = torch.maximum( seq_len, torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device), ) + else: + seq_len = max(seq_len, max_position_embeddings) # Compute the inverse frequencies base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py index fd9f5887b6c4..761a785f369f 100644 --- a/tests/utils/test_modeling_rope_utils.py +++ b/tests/utils/test_modeling_rope_utils.py @@ -220,6 +220,9 @@ def test_dynamic_rope_numerically(self): inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1) torch.testing.assert_close(inv_freq, default_inv_freq) + inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=torch.tensor(1, dtype=torch.int64)) + torch.testing.assert_close(inv_freq, default_inv_freq) + # Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies # will scale up (i.e., the inverse frequencies will scale down). factor = 10.0