
Commit 6c66033

revert loss-masking change
1 parent cbc94e0 commit 6c66033

1 file changed: +2 -2 lines changed


fast_llm/functional/cross_entropy.py

Lines changed: 2 additions & 2 deletions
@@ -160,7 +160,7 @@ def _fused_cross_entropy_forward_backward(
 
     per_sample_loss = sum_exp_logits.log() - predicted_logits
     if loss_mask is not None:
-        per_sample_loss = per_sample_loss[loss_mask]
+        per_sample_loss = per_sample_loss * loss_mask
 
     loss = per_sample_loss.mean()
     if target_format != TargetFormat.labels and group is not None:
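
The functional change is the loss-masking line in `_fused_cross_entropy_forward_backward`: the removed version indexed the per-token losses with the boolean mask, while the restored version multiplies by it. The two are not equivalent once `.mean()` is applied: indexing averages over the kept tokens only, whereas multiplying zeroes the masked tokens but still divides by the full token count, and it also keeps the tensor shape unchanged, which can matter for the fixed-shape reductions in the surrounding code. The snippet below is a minimal standalone sketch of that difference; the tensor names mirror the diff, but the values and shapes are illustrative, not taken from the repository.

```python
import torch

# Illustrative stand-in for per-token losses and a boolean loss mask.
per_sample_loss = torch.tensor([2.0, 4.0, 6.0, 8.0])
loss_mask = torch.tensor([True, True, False, True])

# Removed line: boolean indexing drops masked tokens,
# so the mean is taken over the 3 kept tokens only.
indexed = per_sample_loss[loss_mask].mean()        # (2 + 4 + 8) / 3 ≈ 4.667

# Restored line: multiplying by the mask zeroes masked tokens but keeps the
# original shape, so the mean still divides by all 4 tokens.
masked = (per_sample_loss * loss_mask).mean()      # (2 + 4 + 0 + 8) / 4 = 3.5

print(indexed.item(), masked.item())
```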
@@ -322,7 +322,7 @@ def _torch_reverse_kl_forward_backward(
     # Clamp to prevent extreme values that cause NaNs in log_softmax
     scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0)
     student_log_probs = torch.log_softmax(scaled_logits, dim=-1)
-
+
     # Reverse KL: input=teacher_log_probs, target=student_probs
     if loss_mask is None:
         loss = torch.nn.functional.kl_div(
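
The second hunk is a whitespace-only change, but the surrounding context documents how the reverse-KL loss is formed. `torch.nn.functional.kl_div(input, target)` expects `input` as log-probabilities and computes `target * (log target - input)`, i.e. KL(target || input). Passing `teacher_log_probs` as `input` and the student probabilities as `target` therefore gives KL(student || teacher), the reverse of the usual KL(teacher || student) distillation direction. Below is a hedged sketch of that convention using only public PyTorch APIs; the variable names echo the diff, but the logits are made up for illustration.

```python
import torch

# Toy logits for a single position; shapes and values are illustrative only.
teacher_logits = torch.tensor([[1.0, 0.5, -0.5]])
scaled_logits = torch.tensor([[0.8, 0.2, 0.0]])  # student logits (clamped upstream in the diff)

teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1)
student_log_probs = torch.log_softmax(
    torch.clamp(scaled_logits, min=-100.0, max=100.0), dim=-1
)
student_probs = student_log_probs.exp()

# kl_div(input, target) with input as log-probs computes KL(target || input),
# so input=teacher_log_probs, target=student_probs gives the reverse KL
# KL(student || teacher).
reverse_kl = torch.nn.functional.kl_div(
    teacher_log_probs, student_probs, reduction="batchmean"
)

# Equivalent explicit form, for checking the direction of the divergence.
manual = (student_probs * (student_log_probs - teacher_log_probs)).sum(dim=-1).mean()
print(reverse_kl.item(), manual.item())
```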
