Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions models/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,21 +124,26 @@ def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
# Apply rotary embeddings to the current q and k
q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)

# Check if we can use cached keys and values
if not is_prefill and block_kv_cache['key'] is not None:
# Concatenate with cached K, V
# k_rotated and v_curr are for the new token(s)
k = block_kv_cache['key']
v = block_kv_cache['value']
k = torch.cat([k, k_rotated], dim=2)
v = torch.cat([v, v_curr], dim=2)
block_kv_cache['key'] = k
block_kv_cache['value'] = v
# update KV cache only during validation and inference
if not self.training:
# Check if we can use cached keys and values
if not is_prefill and block_kv_cache['key'] is not None:
# Concatenate with cached K, V
# k_rotated and v_curr are for the new token(s)
k = block_kv_cache['key']
v = block_kv_cache['value']
k = torch.cat([k, k_rotated], dim=2)
v = torch.cat([v, v_curr], dim=2)
block_kv_cache['key'] = k
block_kv_cache['value'] = v
else:
# No cache, this is the first pass (prefill)
k = k_rotated
v = v_curr
block_kv_cache = {'key': k, 'value': v}
else:
# No cache, this is the first pass (prefill)
k = k_rotated
v = v_curr
block_kv_cache = {'key': k, 'value': v}

# Repeat K, V for Grouped Query Attention
k_exp = k.repeat_interleave(self.n_kv_groups, dim=1) # (B, n_heads, T_kv, head_dim)
Expand Down