
Commit 1e92e40

Add loss masking
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent e5ef044 commit 1e92e40

File tree

2 files changed: +37 −10 lines changed

src/speculators/train/eagle3/core.py

Lines changed: 35 additions & 7 deletions
@@ -67,18 +67,48 @@ def load_verifier_lm_head(self, verifier_model_name_or_path: str):
         )
         self.verifier_lm_head.weight.data = verifier_lm_head_data[self.t2d_vocab, :]
 
-    def loss_function(self, logits: torch.Tensor, targets: torch.Tensor):
+    def loss_function(
+        self,
+        logits: torch.Tensor,
+        targets: torch.Tensor,
+        loss_mask: torch.Tensor,
+        ttt_step: int,
+    ):
+        # We don't have target values for the last ttt_step + 1 tokens, so we mask them out on the logit side.
+        # We shift the target values left by ttt_step + 1 because that is the position the generated tokens correspond to,
+        # e.g.
+        #   targets_indices           = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        #   logits_indices_ttt_step_0 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        #   logits_indices_ttt_step_1 = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+        #   logits_indices_ttt_step_2 = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+        # The loss_mask indices need to be kept in line with the targets indices.
+
+        # Note: this function is written so that batch_size > 1 is supported, through careful handling of the 0-th "batch" dimension.
+        # Currently that dimension is always 1 because we pack samples together by extending the sequence (1st) dimension.
+        # batch_size > 1 could become useful in the future if we pad and stack samples instead of packing them.
+        logits = logits[:, : -(ttt_step + 1)]
+        targets = targets[:, (ttt_step + 1) :]
+        # logits/targets shape: [batch_size=1, total_seq_len - (ttt_step + 1), draft_vocab_size]
+        loss_mask = loss_mask[:, (ttt_step + 1) :]
+        # loss_mask shape: [batch_size=1, total_seq_len - (ttt_step + 1)]
+
         logits = torch.nn.functional.log_softmax(logits, dim=-1)
         targets = torch.nn.functional.log_softmax(targets, dim=-1)
-        return torch.nn.functional.kl_div(
-            logits, targets, reduction="sum", log_target=True
-        ) / (logits.shape[0] * logits.shape[1])
+        kl_div = torch.nn.functional.kl_div(
+            logits, targets, reduction="none", log_target=True
+        )
+        masked_kl_div = torch.sum(loss_mask.unsqueeze(-1) * kl_div, dim=(1, 2)) / (
+            loss_mask.sum(dim=1) + 1e-5
+        )
+        # shape: [batch_size=1]
+        return masked_kl_div.mean()
 
     def forward(
         self,
         hidden_states: torch.Tensor,  # shape: [1, total_seq_len, 3 * hidden_size]
         input_ids: torch.Tensor,  # shape: [1, total_seq_len]
         lengths: torch.Tensor | None = None,  # shape: [batch_size]
+        loss_mask: torch.Tensor | None = None,  # shape: [1, total_seq_len]
         verifier_last_hidden_states: torch.Tensor
         | None = None,  # shape: [1, total_seq_len, hidden_size]
         ttt_steps: int | None = None,
@@ -159,9 +189,7 @@ def forward(
         # shape: [1, total_seq_len, draft_vocab_size]
 
         if return_loss:
-            loss += self.loss_function(
-                logits[:, : -(ttt_step + 1)], verifier_logits[:, (ttt_step + 1) :]
-            )
+            loss += self.loss_function(logits, verifier_logits, loss_mask, ttt_step)
 
         input_ids = torch.argmax(logits, dim=-1)
         # shape: [1, total_seq_len]
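For reference, a minimal standalone sketch of the new masked loss. The function name masked_kl_loss and the tensor shapes below are illustrative assumptions, not the repo's API; it reproduces the shift-by-(ttt_step + 1) alignment and mask-weighted KL reduction from the hunk above:

import torch
import torch.nn.functional as F

def masked_kl_loss(
    logits: torch.Tensor,     # [batch, seq_len, draft_vocab_size]
    targets: torch.Tensor,    # [batch, seq_len, draft_vocab_size]
    loss_mask: torch.Tensor,  # [batch, seq_len]; 1.0 = keep, 0.0 = ignore
    ttt_step: int,
) -> torch.Tensor:
    # Drop the last ttt_step + 1 logits (they have no targets) and shift the
    # targets and mask left by ttt_step + 1 so positions line up.
    logits = logits[:, : -(ttt_step + 1)]
    targets = targets[:, (ttt_step + 1) :]
    loss_mask = loss_mask[:, (ttt_step + 1) :]

    log_p = F.log_softmax(logits, dim=-1)
    log_q = F.log_softmax(targets, dim=-1)
    # reduction="none" keeps a per-position KL so the mask can zero it out.
    kl = F.kl_div(log_p, log_q, reduction="none", log_target=True)
    per_sample = torch.sum(loss_mask.unsqueeze(-1) * kl, dim=(1, 2)) / (
        loss_mask.sum(dim=1) + 1e-5  # epsilon guards against an all-zero mask
    )
    return per_sample.mean()

# Smoke test with random tensors; the first 4 positions (e.g. a prompt) are masked out.
logits, targets = torch.randn(1, 16, 32), torch.randn(1, 16, 32)
mask = torch.ones(1, 16)
mask[:, :4] = 0.0
print(masked_kl_loss(logits, targets, mask, ttt_step=0))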

src/speculators/train/training_loop.py

Lines changed: 2 additions & 3 deletions
@@ -9,10 +9,10 @@
 
 DEVICE = "cuda:0"
 EPOCHS = 10
-draft_vocab_size = 4096
+draft_vocab_size = 5000
 verifier_vocab_size = 151936
 hidden_size = 5120
-total_seq_len = 4096
+total_seq_len = 2048
 datapath = "./data"
 verifier_model_name_or_path = "Qwen/Qwen2.5-VL-7B-Instruct"
 
@@ -75,7 +75,6 @@ def train_epoch(
         batch = optree.tree_map(
             lambda x: x.to(local_rank) if isinstance(x, torch.Tensor) else x, batch
        )
-        del batch["loss_mask"]
 
         _, loss = model(**batch, use_off_policy_tokens=True)
         print(loss.item())
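A quick numerical sanity check under assumed toy shapes (batch_size=1): with an all-ones mask, the new reduction="none" path matches the old reduction="sum" / (B * T) loss up to the 1e-5 epsilon in the denominator, so masking is the only behavioral change:

import torch
import torch.nn.functional as F

ttt_step = 0
raw_logits = torch.randn(1, 16, 32)
raw_targets = torch.randn(1, 16, 32)

log_p = F.log_softmax(raw_logits[:, : -(ttt_step + 1)], dim=-1)
log_q = F.log_softmax(raw_targets[:, (ttt_step + 1) :], dim=-1)

# Old loss: summed KL, normalized by batch * sequence length.
old = F.kl_div(log_p, log_q, reduction="sum", log_target=True) / (
    log_p.shape[0] * log_p.shape[1]
)

# New loss, evaluated with an all-ones mask.
mask = torch.ones(1, 16)[:, (ttt_step + 1) :]
kl = F.kl_div(log_p, log_q, reduction="none", log_target=True)
new = (
    torch.sum(mask.unsqueeze(-1) * kl, dim=(1, 2)) / (mask.sum(dim=1) + 1e-5)
).mean()

print(torch.allclose(old, new, atol=1e-4))  # True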
