From 463f36910ffccacde2efba590932069926c4c379 Mon Sep 17 00:00:00 2001
From: Shen Xu
Date: Tue, 15 Jul 2025 11:06:43 -0700
Subject: [PATCH] Source transform for HF RoPE in static attention (#12500)

Summary:
Llama-style RoPE shuffles/unshuffles the embedding dimension, which is not as
efficient as HF-style RoPE. Add a source transform on static attention so
HF-style RoPE can be used.

Reviewed By: billmguo

Differential Revision: D78353775
---
 examples/models/llama/static_attention.py | 28 +++++++
 .../llama/tests/test_static_attention.py  | 74 +++++++++----------
 2 files changed, 61 insertions(+), 41 deletions(-)

diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 8f3486353f2..03a9289924e 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -730,6 +730,34 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
             self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
             self.k_norm.load_state_dict(other.k_norm_fn.state_dict())
 
+    def adopt_hf_rope(self):
+        if self.rope.use_hf_rope:
+            return
+
+        if self.use_conv2d:
+            raise RuntimeError(
+                "adopt_hf_rope needs to be called before linear_to_conv2d"
+            )
+
+        # Permute weights of qk projections and norms to match HF RoPE's channel order.
+        def permute(w):
+            shape = w.shape
+            return (
+                w.view((-1, 2) + shape[1:]).transpose(0, 1).reshape(shape).contiguous()
+            )
+
+        for wq in self.wqs:
+            wq.weight.data.copy_(permute(wq.weight.data))
+
+        for wk in self.wks:
+            wk.weight.data.copy_(permute(wk.weight.data))
+
+        if self.use_qk_norm:
+            self.q_norm.weight.data.copy_(permute(self.q_norm.weight.data))
+            self.k_norm.weight.data.copy_(permute(self.k_norm.weight.data))
+
+        self.rope.use_hf_rope = True
+
     def linear_to_conv2d(self):
         def transfer_weight(linear, conv2d):
             conv2d.weight.data.copy_(linear.weight[:, :, None, None])
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index 44a483fe981..72a70140a20 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -1,3 +1,4 @@
+import itertools
 import unittest
 from collections import defaultdict
 
@@ -19,13 +20,18 @@ def setUp(self):
         torch.manual_seed(42)
 
     def test_without_cache(self):
-        def test(use_qk_norm, use_conv2d):
+        def test(use_qk_norm, qk_norm_before_rope, adopt_hf_rope, use_conv2d):
+            if not use_qk_norm and qk_norm_before_rope:
+                # Redundant test.
+                return
+
             config = ModelArgs(
                 dim=64,
                 n_heads=4,
                 n_kv_heads=2,
                 max_seq_len=8,
                 use_qk_norm=use_qk_norm,
+                qk_norm_before_rope=qk_norm_before_rope,
             )
             layer_id = 0
             rope = Rope(config)
@@ -40,12 +46,19 @@ def test(use_qk_norm, use_conv2d):
                         torch.rand(config.head_dim) * 0.2 + 0.9
                     )
             static_attn.load_weights_from_attention_mha(attn_mha)
+            if adopt_hf_rope:
+                static_attn.adopt_hf_rope()
             if use_conv2d:
                 static_attn.linear_to_conv2d()
 
             x = torch.rand(1, config.max_seq_len, config.dim)
             freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
             expected, _ = attn_mha(x, freqs_cos, freqs_sin)
+
+            if adopt_hf_rope:
+                config.use_hf_rope = True
+                rope = Rope(config)
+                freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
             mask = torch.triu(
                 torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
                 diagonal=1,
@@ -56,45 +69,16 @@ def test(use_qk_norm, use_conv2d):
                 freqs_sin,
                 mask=mask,
             )
-            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
-
-        test(True, True)
-        test(True, False)
-        test(False, True)
-        test(False, False)
-
-    def test_hf_rope_without_cache(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=8,
-            use_qk_norm=True,
-            use_hf_rope=True,
-        )
-        layer_id = 0
-        rope = Rope(config)
-        attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        with torch.no_grad():
-            attn_mha.q_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
-            attn_mha.k_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+            self.assertTrue(
+                torch.isclose(y, expected, rtol=1e-3).all(),
+                f"Failed for use_qk_norm={use_qk_norm}, "
+                f"qk_norm_before_rope={qk_norm_before_rope}, "
+                f"adopt_hf_rope={adopt_hf_rope}, "
+                f"use_conv2d={use_conv2d}",
+            )
 
-        x = torch.rand(1, config.max_seq_len, config.dim)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
-        expected, _ = attn_mha(x, freqs_cos, freqs_sin)
-        mask = torch.triu(
-            torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
-            diagonal=1,
-        )
-        y, _ = static_attn(
-            x,
-            freqs_cos.unsqueeze(0),
-            freqs_sin.unsqueeze(0),
-            mask=mask,
-        )
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+        for args in itertools.product([False, True], repeat=4):
+            test(*args)
 
     def test_with_cache(self):
         config = ModelArgs(
@@ -108,6 +92,7 @@ def test_with_cache(self):
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
         static_attn = StaticAttention(config, layer_id, rope).eval()
         static_attn.load_weights_from_attention_mha(attn_mha)
+        static_attn.adopt_hf_rope()
 
         x = torch.rand(1, config.max_seq_len, config.dim)
         freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
@@ -117,6 +102,10 @@ def test_with_cache(self):
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
 
+        config.use_hf_rope = True
+        hf_rope = Rope(config)
+        hf_freqs_cos, hf_freqs_sin = hf_rope.get_freqs(None, config.max_seq_len)
+
         def test_with_style(style):
             mask = StaticAttentionMask(chunk_len, cache_len, style=style)
             mask.tensor[:, :, cache_len:] = torch.triu(
@@ -139,8 +128,8 @@ def test_with_style(style):
             for i in range(n_chunks):
                 y_i, attn_update = static_attn(
                     x[:, i * chunk_len : (i + 1) * chunk_len, :],
-                    freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                    freqs_sin[i * chunk_len : (i + 1) * chunk_len],
+                    hf_freqs_cos[i * chunk_len : (i + 1) * chunk_len],
+                    hf_freqs_sin[i * chunk_len : (i + 1) * chunk_len],
                     mask=mask.tensor,
                     in_cache_state=(k_caches, v_caches),
                     out_cache_state=({}, {}),
@@ -175,6 +164,7 @@ def _get_test_transformers(self, config):
             mha_transformer.layers, static_transformer.layers
         ):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
+            static_layer.attention.adopt_hf_rope()
 
         return mha_transformer, static_transformer
 
@@ -196,6 +186,7 @@ def test_within_transformer(self):
         cache_len = config.max_seq_len - chunk_len
 
         def test_with_style(style):
+            config.use_hf_rope = True
             mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
@@ -222,6 +213,7 @@ def test_lookahead_decode(self):
         )
         _, static_transformer = self._get_test_transformers(config)
 
+        config.use_hf_rope = True
        input_len = 32
        cache_len = config.max_seq_len - input_len
        prefill_input = torch.randint(config.vocab_size, (input_len,))
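Usage note (illustrative, not part of the patch): the sketch below mirrors the call order exercised by the updated test_without_cache. adopt_hf_rope permutes the Q/K projection weights (and the Q/K norm weights when use_qk_norm is set) into HF RoPE's channel order, so it must run after load_weights_from_attention_mha and before linear_to_conv2d, and the module must afterwards be fed HF-style frequencies (config.use_hf_rope = True). The import paths are assumptions based on the ExecuTorch llama example layout.

import torch

# Assumed import paths; these classes live under examples/models/llama/ in
# ExecuTorch -- adjust the module paths to your checkout.
from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from executorch.examples.models.llama.static_attention import StaticAttention

config = ModelArgs(dim=64, n_heads=4, n_kv_heads=2, max_seq_len=8)
rope = Rope(config)
attn_mha = AttentionMHA(config, 0, rope).eval()

# Load the MHA weights first, then switch to HF RoPE, then (optionally) lower
# linear layers to conv2d; adopt_hf_rope raises if called after linear_to_conv2d.
static_attn = StaticAttention(config, 0, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)
static_attn.adopt_hf_rope()
static_attn.linear_to_conv2d()

# After the transform the module expects HF-style frequencies.
config.use_hf_rope = True
hf_rope = Rope(config)
freqs_cos, freqs_sin = hf_rope.get_freqs(None, config.max_seq_len)

x = torch.rand(1, config.max_seq_len, config.dim)
mask = torch.triu(
    torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
    diagonal=1,
)
y, _ = static_attn(x, freqs_cos, freqs_sin, mask=mask)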