Skip to content

Source transform for HF RoPE in static attention #12500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions examples/models/llama/static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,34 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
self.k_norm.load_state_dict(other.k_norm_fn.state_dict())

def adopt_hf_rope(self):
if self.rope.use_hf_rope:
return

if self.use_conv2d:
raise RuntimeError(
"adopt_hf_rope needs to be called before linear_to_conv2d"
)

# Permute weights of qk projections and norms to match HF RoPE's channel order.
def permute(w):
shape = w.shape
return (
w.view((-1, 2) + shape[1:]).transpose(0, 1).reshape(shape).contiguous()
)

for wq in self.wqs:
wq.weight.data.copy_(permute(wq.weight.data))

for wk in self.wks:
wk.weight.data.copy_(permute(wk.weight.data))

if self.use_qk_norm:
self.q_norm.weight.data.copy_(permute(self.q_norm.weight.data))
self.k_norm.weight.data.copy_(permute(self.k_norm.weight.data))

self.rope.use_hf_rope = True

def linear_to_conv2d(self):
def transfer_weight(linear, conv2d):
conv2d.weight.data.copy_(linear.weight[:, :, None, None])
Expand Down
74 changes: 33 additions & 41 deletions examples/models/llama/tests/test_static_attention.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import unittest
from collections import defaultdict

Expand All @@ -19,13 +20,18 @@ def setUp(self):
torch.manual_seed(42)

def test_without_cache(self):
def test(use_qk_norm, use_conv2d):
def test(use_qk_norm, qk_norm_before_rope, adopt_hf_rope, use_conv2d):
if not use_qk_norm and qk_norm_before_rope:
# Redundant test.
return

config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
use_qk_norm=use_qk_norm,
qk_norm_before_rope=qk_norm_before_rope,
)
layer_id = 0
rope = Rope(config)
Expand All @@ -40,12 +46,19 @@ def test(use_qk_norm, use_conv2d):
torch.rand(config.head_dim) * 0.2 + 0.9
)
static_attn.load_weights_from_attention_mha(attn_mha)
if adopt_hf_rope:
static_attn.adopt_hf_rope()
if use_conv2d:
static_attn.linear_to_conv2d()

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
expected, _ = attn_mha(x, freqs_cos, freqs_sin)

if adopt_hf_rope:
config.use_hf_rope = True
rope = Rope(config)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
mask = torch.triu(
torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
diagonal=1,
Expand All @@ -56,45 +69,16 @@ def test(use_qk_norm, use_conv2d):
freqs_sin,
mask=mask,
)
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

test(True, True)
test(True, False)
test(False, True)
test(False, False)

def test_hf_rope_without_cache(self):
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
use_qk_norm=True,
use_hf_rope=True,
)
layer_id = 0
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope).eval()
with torch.no_grad():
attn_mha.q_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
attn_mha.k_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
static_attn = StaticAttention(config, layer_id, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)
self.assertTrue(
torch.isclose(y, expected, rtol=1e-3).all(),
f"Failed for use_qk_norm={use_qk_norm}, "
f"qk_norm_before_rope={qk_norm_before_rope}, "
f"adopt_hf_rope={adopt_hf_rope}, "
f"use_conv2d={use_conv2d}",
)

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
expected, _ = attn_mha(x, freqs_cos, freqs_sin)
mask = torch.triu(
torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
diagonal=1,
)
y, _ = static_attn(
x,
freqs_cos.unsqueeze(0),
freqs_sin.unsqueeze(0),
mask=mask,
)
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
for args in itertools.product([False, True], repeat=4):
test(*args)

def test_with_cache(self):
config = ModelArgs(
Expand All @@ -108,6 +92,7 @@ def test_with_cache(self):
attn_mha = AttentionMHA(config, layer_id, rope).eval()
static_attn = StaticAttention(config, layer_id, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)
static_attn.adopt_hf_rope()

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
Expand All @@ -117,6 +102,10 @@ def test_with_cache(self):
chunk_len = config.max_seq_len // n_chunks
cache_len = config.max_seq_len - chunk_len

config.use_hf_rope = True
hf_rope = Rope(config)
hf_freqs_cos, hf_freqs_sin = hf_rope.get_freqs(None, config.max_seq_len)

def test_with_style(style):
mask = StaticAttentionMask(chunk_len, cache_len, style=style)
mask.tensor[:, :, cache_len:] = torch.triu(
Expand All @@ -139,8 +128,8 @@ def test_with_style(style):
for i in range(n_chunks):
y_i, attn_update = static_attn(
x[:, i * chunk_len : (i + 1) * chunk_len, :],
freqs_cos[i * chunk_len : (i + 1) * chunk_len],
freqs_sin[i * chunk_len : (i + 1) * chunk_len],
hf_freqs_cos[i * chunk_len : (i + 1) * chunk_len],
hf_freqs_sin[i * chunk_len : (i + 1) * chunk_len],
mask=mask.tensor,
in_cache_state=(k_caches, v_caches),
out_cache_state=({}, {}),
Expand Down Expand Up @@ -175,6 +164,7 @@ def _get_test_transformers(self, config):
mha_transformer.layers, static_transformer.layers
):
static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
static_layer.attention.adopt_hf_rope()

return mha_transformer, static_transformer

Expand All @@ -196,6 +186,7 @@ def test_within_transformer(self):
cache_len = config.max_seq_len - chunk_len

def test_with_style(style):
config.use_hf_rope = True
mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
ys = []
for i in range(n_chunks):
Expand All @@ -222,6 +213,7 @@ def test_lookahead_decode(self):
)
_, static_transformer = self._get_test_transformers(config)

config.use_hf_rope = True
input_len = 32
cache_len = config.max_seq_len - input_len
prefill_input = torch.randint(config.vocab_size, (input_len,))
Expand Down
Loading