1 file changed: +2 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ def _fused_cross_entropy_forward_backward(
160 160
161 161   per_sample_loss = sum_exp_logits.log() - predicted_logits
162 162   if loss_mask is not None:
163     - per_sample_loss = per_sample_loss[loss_mask]
    163 + per_sample_loss = per_sample_loss * loss_mask
164 164
165 165   loss = per_sample_loss.mean()
166 166   if target_format != TargetFormat.labels and group is not None:
@@ -322,7 +322,7 @@ def _torch_reverse_kl_forward_backward(
322 322   # Clamp to prevent extreme values that cause NaNs in log_softmax
323 323   scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0)
324 324   student_log_probs = torch.log_softmax(scaled_logits, dim=-1)
325     - (blank line; whitespace-only change)
    325 + (blank line; whitespace-only change)
326 326   # Reverse KL: input=teacher_log_probs, target=student_probs
327 327   if loss_mask is None:
328 328   loss = torch.nn.functional.kl_div(
You can’t perform that action at this time.
0 commit comments