@@ -67,18 +67,48 @@ def load_verifier_lm_head(self, verifier_model_name_or_path: str):
         )
         self.verifier_lm_head.weight.data = verifier_lm_head_data[self.t2d_vocab, :]
 
-    def loss_function(self, logits: torch.Tensor, targets: torch.Tensor):
+    def loss_function(
+        self,
+        logits: torch.Tensor,
+        targets: torch.Tensor,
+        loss_mask: torch.Tensor,
+        ttt_step: int,
+    ):
+        # We don't have target values for the last ttt_step + 1 tokens, so we mask them out on the logit side.
+        # We shift the target values left by ttt_step + 1 because that is the position the generated tokens correspond to.
+        # e.g.
+        #   targets_indices           = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        #   logits_indices_ttt_step_0 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        #   logits_indices_ttt_step_1 = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+        #   logits_indices_ttt_step_2 = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+        # The indices for the loss_mask need to be kept in line with the target indices.
+
+        # Note: this function is written so that batch_size > 1 is supported, through careful handling of the 0-th "batch" dimension.
+        # However, the 0-th "batch" dimension is currently always 1 because we pack samples together by extending the sequence (1st) dimension.
+        # There could be a future use case for batch_size > 1 if samples were padded and stacked instead of packed.
+        logits = logits[:, : -(ttt_step + 1)]
+        targets = targets[:, (ttt_step + 1) :]
+        # logits/targets shape: [batch_size=1, total_seq_len - (ttt_step + 1), draft_vocab_size]
+        loss_mask = loss_mask[:, (ttt_step + 1) :]
+        # loss_mask shape: [batch_size=1, total_seq_len - (ttt_step + 1)]
+
         logits = torch.nn.functional.log_softmax(logits, dim=-1)
         targets = torch.nn.functional.log_softmax(targets, dim=-1)
-        return torch.nn.functional.kl_div(
-            logits, targets, reduction="sum", log_target=True
-        ) / (logits.shape[0] * logits.shape[1])
+        kl_div = torch.nn.functional.kl_div(
+            logits, targets, reduction="none", log_target=True
+        )
+        masked_kl_div = torch.sum(loss_mask.unsqueeze(-1) * kl_div, dim=(1, 2)) / (
+            loss_mask.sum(dim=1) + 1e-5
+        )
+        # shape: [batch_size=1]
+        return masked_kl_div.mean()
 
     def forward(
         self,
         hidden_states: torch.Tensor,  # shape: [1, total_seq_len, 3 * hidden_size]
         input_ids: torch.Tensor,  # shape: [1, total_seq_len]
         lengths: torch.Tensor | None = None,  # shape: [batch_size]
+        loss_mask: torch.Tensor | None = None,  # shape: [1, total_seq_len]
         verifier_last_hidden_states: torch.Tensor
         | None = None,  # shape: [1, total_seq_len, hidden_size]
         ttt_steps: int | None = None,
@@ -159,9 +189,7 @@ def forward(
             # shape: [1, total_seq_len, draft_vocab_size]
 
             if return_loss:
-                loss += self.loss_function(
-                    logits[:, : -(ttt_step + 1)], verifier_logits[:, (ttt_step + 1) :]
-                )
+                loss += self.loss_function(logits, verifier_logits, loss_mask, ttt_step)
 
             input_ids = torch.argmax(logits, dim=-1)
             # shape: [1, total_seq_len]
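For reference, a minimal self-contained sketch of the shift-and-mask KL computation that the new loss_function performs; the helper name masked_shifted_kl and the toy tensor sizes are illustrative assumptions, not part of this change:

# Standalone sketch of the shifted, masked KL objective; only torch is assumed.
import torch
import torch.nn.functional as F


def masked_shifted_kl(logits, targets, loss_mask, ttt_step):
    # Align draft logits with verifier targets: drop the last (ttt_step + 1)
    # positions on the logit side and the first (ttt_step + 1) on the target side.
    logits = logits[:, : -(ttt_step + 1)]
    targets = targets[:, (ttt_step + 1) :]
    loss_mask = loss_mask[:, (ttt_step + 1) :]
    log_p = F.log_softmax(logits, dim=-1)
    log_q = F.log_softmax(targets, dim=-1)
    kl = F.kl_div(log_p, log_q, reduction="none", log_target=True)
    # Zero out masked positions, normalize by the number of unmasked tokens per
    # sequence, then average over the (currently size-1) batch dimension.
    per_seq = torch.sum(loss_mask.unsqueeze(-1) * kl, dim=(1, 2)) / (
        loss_mask.sum(dim=1) + 1e-5
    )
    return per_seq.mean()


draft_logits = torch.randn(1, 10, 32)     # [batch=1, total_seq_len, draft_vocab_size]
verifier_logits = torch.randn(1, 10, 32)  # [batch=1, total_seq_len, draft_vocab_size]
mask = torch.ones(1, 10)                  # 1.0 where a position contributes to the loss
print(masked_shifted_kl(draft_logits, verifier_logits, mask, ttt_step=0))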