models/language_model.py (45 changes: 26 additions & 19 deletions)
@@ -112,7 +112,7 @@ def __init__(self, cfg):
         if not self.sdpa:
             print("Warning: scaled dot product attention not available, using standard attention in LM.")

-    def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
+    def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None, is_training=True):
         is_prefill = block_kv_cache is None

         B, T_curr, C = x.size() # T_curr is the sequence length of the current input x
@@ -124,21 +124,26 @@ def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
         # Apply rotary embeddings to the current q and k
         q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)

-        # Check if we can use cached keys and values
-        if not is_prefill and block_kv_cache['key'] is not None:
-            # Concatenate with cached K, V
-            # k_rotated and v_curr are for the new token(s)
-            k = block_kv_cache['key']
-            v = block_kv_cache['value']
-            k = torch.cat([k, k_rotated], dim=2)
-            v = torch.cat([v, v_curr], dim=2)
-            block_kv_cache['key'] = k
-            block_kv_cache['value'] = v
+        # create and update block_kv_cache only during inference
+        if not is_training:
+            # Check if we can use cached keys and values
+            if not is_prefill and block_kv_cache['key'] is not None:
+                # Concatenate with cached K, V
+                # k_rotated and v_curr are for the new token(s)
+                k = block_kv_cache['key']
+                v = block_kv_cache['value']
+                k = torch.cat([k, k_rotated], dim=2)
+                v = torch.cat([v, v_curr], dim=2)
+                block_kv_cache['key'] = k
+                block_kv_cache['value'] = v
+            else:
+                # No cache, this is the first pass (prefill)
+                k = k_rotated
+                v = v_curr
+                block_kv_cache = {'key': k, 'value': v}
         else:
             # No cache, this is the first pass (prefill)
             k = k_rotated
             v = v_curr
             block_kv_cache = {'key': k, 'value': v}

         # Repeat K, V for Grouped Query Attention
         k_exp = k.repeat_interleave(self.n_kv_groups, dim=1) # (B, n_heads, T_kv, head_dim)
@@ -215,10 +220,10 @@ def __init__(self, cfg):
         self.norm1 = RMSNorm(cfg) # Input Norm
         self.norm2 = RMSNorm(cfg) # Post Attention Norm

-    def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
+    def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None, is_training=True):
         res = x
         x = self.norm1(x)
-        x, block_kv_cache = self.attn(x, cos, sin, attention_mask, block_kv_cache)
+        x, block_kv_cache = self.attn(x, cos, sin, attention_mask, block_kv_cache, is_training)
         x = res + x

         res = x
Expand Down Expand Up @@ -258,7 +263,7 @@ def _init_weights(self, module):
elif isinstance(module, RMSNorm):
module.weight.data.fill_(1.0)

def forward(self, x, attention_mask=None, kv_cache=None, start_pos=0):
def forward(self, x, attention_mask=None, kv_cache=None, start_pos=0, is_training=True):
if self.lm_use_tokens:
x = self.token_embedding(x)

@@ -274,7 +279,7 @@ def forward(self, x, attention_mask=None, kv_cache=None, start_pos=0):
             kv_cache = [None] * len(self.blocks)

         for i, block in enumerate(self.blocks):
-            x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])
+            x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i], is_training)

         x = self.norm(x)

@@ -296,7 +301,8 @@ def generate(self, inputs, max_new_tokens=20):
             generated_outputs,
             attention_mask=None,
             kv_cache=None,
-            start_pos=0
+            start_pos=0,
+            is_training=False
         )
         last_output = prompt_output[:, -1, :]

@@ -321,7 +327,8 @@ def generate(self, inputs, max_new_tokens=20):
                 next_output,
                 attention_mask=None,
                 kv_cache=kv_cache_list,
-                start_pos=current_token_start_pos
+                start_pos=current_token_start_pos,
+                is_training=False
             )
             last_output = decode_step_output[:, -1, :]

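The two generate() hunks above follow the usual prefill/decode split: one full pass over the prompt with kv_cache=None and start_pos=0, then single-token steps that reuse the returned per-block cache list and advance start_pos, with is_training=False on both paths so the attention blocks actually build and extend their caches. The sketch below mirrors that call pattern against a stand-in decoder so it runs on its own; DummyDecoder, its sizes, and its "positions seen" bookkeeping are illustrative assumptions, not the repo's classes.

```python
import torch
import torch.nn as nn

class DummyDecoder(nn.Module):
    """Stand-in with the same call signature as LanguageModel.forward."""
    def __init__(self, dim=16, n_blocks=2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_blocks)])

    def forward(self, x, attention_mask=None, kv_cache=None, start_pos=0, is_training=True):
        if kv_cache is None:
            kv_cache = [None] * len(self.blocks)
        for i, block in enumerate(self.blocks):
            x = block(x)
            if not is_training:
                # A real block would store K/V tensors; here we only track
                # how many positions this block has seen so far.
                kv_cache[i] = (kv_cache[i] or 0) + x.size(1)
        return x, kv_cache

decoder = DummyDecoder()
prompt_embeds = torch.randn(1, 5, 16)  # (B, T_prompt, C)

# Prefill: one pass over the whole prompt, cache starts at position 0.
out, kv_cache_list = decoder(prompt_embeds, kv_cache=None, start_pos=0, is_training=False)
current_token_start_pos = prompt_embeds.size(1)

# Decode: one embedding per step, reusing the cache and advancing start_pos.
for _ in range(3):
    next_embed = out[:, -1:, :]  # stand-in for the newly sampled token's embedding
    out, kv_cache_list = decoder(
        next_embed,
        kv_cache=kv_cache_list,
        start_pos=current_token_start_pos,
        is_training=False,
    )
    current_token_start_pos += 1

print(kv_cache_list)  # each block saw 5 prompt positions + 3 decode steps = 8
```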
models/vision_language_model.py (10 changes: 6 additions & 4 deletions)
@@ -30,7 +30,7 @@ def __init__(self, cfg: VLMConfig, load_backbone=True):
         self.MP = ModalityProjector(cfg)
         self.load_backbone = load_backbone

-    def forward(self, input_ids, image, attention_mask=None, targets=None):
+    def forward(self, input_ids, image, attention_mask=None, targets=None, is_training=True):
         image_embd = self.vision_encoder(image)
         image_embd = self.MP(image_embd)

@@ -48,7 +48,7 @@ def forward(self, input_ids, image, attention_mask=None, targets=None):
             # Combine image and token attention masks
             attention_mask = torch.cat((image_attention_mask, attention_mask), dim=1)

-        logits, _ = self.decoder(combined_embd, attention_mask=attention_mask) # Not logits yet, but easier to return like this
+        logits, _ = self.decoder(combined_embd, attention_mask=attention_mask, is_training=is_training) # Not logits yet, but easier to return like this

         loss = None
         if targets is not None:
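For context on the mask handling visible in this hunk: the projected image embeddings sit in front of the token embeddings, so an image attention mask is concatenated in front of the token mask along the sequence dimension before the combined sequence reaches the decoder. A small sketch of that step, assuming (as is typical) that image positions are never masked out:

```python
import torch

# Token mask: (B, T_text), 1 = real token, 0 = padding.
token_mask = torch.tensor([[1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 1]])
B, n_img_tokens = token_mask.size(0), 4

# Image positions are always visible, so their mask is all ones; it goes in
# front because the image embeddings are prepended to the token embeddings.
image_attention_mask = torch.ones(B, n_img_tokens, dtype=token_mask.dtype)
combined_mask = torch.cat((image_attention_mask, token_mask), dim=1)
print(combined_mask.shape)  # torch.Size([2, 9])
```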
@@ -87,7 +87,8 @@ def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_
             initial_combined_embeds,
             attention_mask=attention_mask,
             kv_cache=None,
-            start_pos=0
+            start_pos=0,
+            is_training=False
         )

         last_token_output_from_prefill = prefill_output[:, -1, :]
@@ -127,7 +128,8 @@ def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5, top_
                 next_token_embed,
                 attention_mask=attention_mask,
                 kv_cache=kv_cache_list,
-                start_pos=current_token_start_pos
+                start_pos=current_token_start_pos,
+                is_training=False
             )

             last_token_output = decode_step_output[:, -1, :]