From f801ca77ea6fa2186140ec06482d66591ff0e17b Mon Sep 17 00:00:00 2001
From: Aditya Vipradas
Date: Sat, 31 May 2025 12:17:45 +0530
Subject: [PATCH 1/3] modified language_model.py

added a self.training check to build the KV cache only during inference and
evaluation
---
 models/language_model.py | 29 +++++++++++++++++------------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/models/language_model.py b/models/language_model.py
index f40d3871..e5dde428 100644
--- a/models/language_model.py
+++ b/models/language_model.py
@@ -124,21 +124,26 @@ def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
         # Apply rotary embeddings to the current q and k
         q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)
 
-        # Check if we can use cached keys and values
-        if not is_prefill and block_kv_cache['key'] is not None:
-            # Concatenate with cached K, V
-            # k_rotated and v_curr are for the new token(s)
-            k = block_kv_cache['key']
-            v = block_kv_cache['value']
-            k = torch.cat([k, k_rotated], dim=2)
-            v = torch.cat([v, v_curr], dim=2)
-            block_kv_cache['key'] = k
-            block_kv_cache['value'] = v
+        # update KV cache only during inference
+        if not self.training:
+            # Check if we can use cached keys and values
+            if not is_prefill and block_kv_cache['key'] is not None:
+                # Concatenate with cached K, V
+                # k_rotated and v_curr are for the new token(s)
+                k = block_kv_cache['key']
+                v = block_kv_cache['value']
+                k = torch.cat([k, k_rotated], dim=2)
+                v = torch.cat([v, v_curr], dim=2)
+                block_kv_cache['key'] = k
+                block_kv_cache['value'] = v
+            else:
+                # No cache, this is the first pass (prefill)
+                k = k_rotated
+                v = v_curr
+                block_kv_cache = {'key': k, 'value': v}
         else:
-            # No cache, this is the first pass (prefill)
             k = k_rotated
             v = v_curr
-            block_kv_cache = {'key': k, 'value': v}
 
         # Repeat K, V for Grouped Query Attention
         k_exp = k.repeat_interleave(self.n_kv_groups, dim=1) # (B, n_heads, T_kv, head_dim)

From 0f791f98e643aa26c2a21b679446cf033d49190b Mon Sep 17 00:00:00 2001
From: Aditya Vipradas
Date: Sat, 31 May 2025 14:15:26 +0530
Subject: [PATCH 2/3] modified language_model.py 2

replaced self.training with torch.is_grad_enabled()
---
 models/language_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/language_model.py b/models/language_model.py
index e5dde428..7c9a0dc3 100644
--- a/models/language_model.py
+++ b/models/language_model.py
@@ -125,7 +125,7 @@ def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
         q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)
 
         # update KV cache only during inference
-        if not self.training:
+        if not torch.is_grad_enabled():
             # Check if we can use cached keys and values
             if not is_prefill and block_kv_cache['key'] is not None:
                 # Concatenate with cached K, V
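A minimal, self-contained sketch (not part of the patches) of where the two
gating flags differ: self.training follows model.train()/model.eval(), while
torch.is_grad_enabled() follows grad-mode contexts such as torch.no_grad().
The two agree in the usual train and generate loops, but can diverge:

import torch
import torch.nn as nn

m = nn.Linear(4, 4)

m.eval()                     # module flag: not training
with torch.enable_grad():    # grad mode: enabled
    # The two checks disagree here (e.g. evaluation with grads on)
    print(m.training, torch.is_grad_enabled())  # False True

m.train()                    # module flag: training
with torch.no_grad():        # grad mode: disabled
    # And they disagree the other way here
    print(m.training, torch.is_grad_enabled())  # True False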
From c2accf2357462453a01ad955cf864bd965ca65ca Mon Sep 17 00:00:00 2001
From: Aditya Vipradas
Date: Sat, 31 May 2025 15:48:49 +0530
Subject: [PATCH 3/3] changed torch.is_grad_enabled() to self.training

Both torch.is_grad_enabled() and self.training lead to the same KV cache
building outcome.
---
 models/language_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/models/language_model.py b/models/language_model.py
index 7c9a0dc3..3f9fcfc0 100644
--- a/models/language_model.py
+++ b/models/language_model.py
@@ -124,8 +124,8 @@ def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
         # Apply rotary embeddings to the current q and k
         q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)
 
-        # update KV cache only during inference
-        if not torch.is_grad_enabled():
+        # update KV cache only during validation and inference
+        if not self.training:
             # Check if we can use cached keys and values
             if not is_prefill and block_kv_cache['key'] is not None:
                 # Concatenate with cached K, V
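For reference, the end state of the gated branch, lifted into a standalone
helper. This is a sketch under assumptions, not code from the repo: the name
update_kv_cache is hypothetical, block_kv_cache mirrors the diff's dict of
{'key': Tensor, 'value': Tensor} with time on dim 2, and is_prefill is
computed upstream as in the original forward().

import torch

def update_kv_cache(training, is_prefill, block_kv_cache, k_rotated, v_curr):
    # Mirrors the final patched branch: cache K/V only when not training.
    if not training:
        if not is_prefill and block_kv_cache['key'] is not None:
            # Decode step: append the new token's K/V to the cached K/V
            k = torch.cat([block_kv_cache['key'], k_rotated], dim=2)
            v = torch.cat([block_kv_cache['value'], v_curr], dim=2)
            block_kv_cache['key'] = k
            block_kv_cache['value'] = v
        else:
            # Prefill: seed the cache with the prompt's K/V
            k, v = k_rotated, v_curr
            block_kv_cache = {'key': k, 'value': v}
    else:
        # Training: no cache; attention sees only the current K/V
        k, v = k_rotated, v_curr
    return k, v, block_kv_cache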