@le1nux le1nux commented Jul 3, 2025

What does this PR do?

This PR ..

General Changes

  • ..

Breaking Changes

  • ..

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

@le1nux le1nux left a comment

This is a draft.

What has been done so far:

  • Refactored the collator design to be composable, meaning multiple collate_fns can be called sequentially. This is helpful, for instance, for instruction tuning, where we first shift the targets for the autoregressive objective and then mask certain tokens (such as input from the user role) so they are disregarded in the loss. In this case, two collate_fns are called one after the other (see the first sketch below).

  • We can now specify which components to load into the app_state from a DCP checkpoint. This is important for continued pretraining or finetuning, where we don't want to reuse the previous optimizer and LR scheduler states (see the second sketch below).

  • Added an iterative memory-mapped dataset implementation. Instead of packing the samples, we use the index stored in the pbin file to iterate over each sample individually (see the third sketch below).
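
A minimal sketch of the composable collator idea (the names CollateFnIF, ShiftTargetsCollateFn, and MaskRoleTokensCollateFn are illustrative placeholders, not the actual Modalities classes):

    from typing import Protocol

    class CollateFnIF(Protocol):
        def __call__(self, batch: dict) -> dict: ...

    class SequentialCollator:
        """Applies a list of collate_fns one after the other."""

        def __init__(self, collate_fns: list[CollateFnIF]):
            self.collate_fns = collate_fns

        def __call__(self, batch: dict) -> dict:
            for fn in self.collate_fns:
                batch = fn(batch)
            return batch

    # For instruction tuning: first shift the targets, then mask user-role tokens.
    # collator = SequentialCollator(
    #     [ShiftTargetsCollateFn(), MaskRoleTokensCollateFn(ignore_index=-100)]
    # )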
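
For the selective app_state loading, a rough sketch using PyTorch's torch.distributed.checkpoint (the component keys and app_state layout here are assumptions, not the actual Modalities implementation):

    import torch.nn as nn
    import torch.distributed.checkpoint as dcp

    model = nn.Linear(8, 8)  # stand-in for the actual model

    # The full app_state contains model, optimizer, and lr_scheduler states.
    # dcp.load only populates the keys present in the passed state_dict, so
    # handing it just the model entry skips the optimizer/scheduler restore.
    state_dict = {"model": model.state_dict()}
    dcp.load(state_dict, checkpoint_id="/path/to/dcp_checkpoint")
    model.load_state_dict(state_dict["model"])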
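
And a sketch of the iterative memory-mapped dataset; the pbin layout assumed here (a flat token array plus an (offset, length) index per sample) is a simplification:

    import numpy as np
    from torch.utils.data import Dataset

    class IterativeMemMapDataset(Dataset):
        """Returns one (unpacked) sample per index instead of packing samples."""

        def __init__(self, token_file: str, index: list[tuple[int, int]]):
            self.tokens = np.memmap(token_file, dtype=np.uint16, mode="r")
            self.index = index  # (offset, length) per sample, from the pbin index

        def __len__(self) -> int:
            return len(self.index)

        def __getitem__(self, i: int) -> np.ndarray:
            offset, length = self.index[i]
            return np.asarray(self.tokens[offset : offset + length])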

Todos:

  • Add tests for the new functionality
  • Fix failing tests
  • Verify correctness of the code
  • Sometimes we skip samples by masking out all of their tokens in the loss (setting the target token IDs to -100). When a whole batch consists only of such "skipped" samples, the loss becomes NaN; with batch size 1 this happens quite frequently. We need a way to avoid running backprop on such batches (see the sketch below).
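
One possible guard, as a sketch (whether to skip the step entirely or zero the loss is an open design question): check the number of unmasked targets before computing the loss and skip backprop when it is zero.

    import torch
    import torch.nn.functional as F

    def safe_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor | None:
        """Returns None when every target in the batch is masked out (-100)."""
        num_valid = (targets != -100).sum()
        if num_valid == 0:
            return None  # caller skips backprop / optimizer step for this batch
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100
        )

    # loss = safe_loss(logits, targets)
    # if loss is not None:
    #     loss.backward()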

@le1nux le1nux changed the title from "App state refactoring" to "Instruction Tuning Improvements" on Jul 6, 2025
lllAlexanderlll commented Aug 18, 2025

The current plan is to make the instruction-tuning data preparation independent of special tokens, as #383 persists.
This can be done using Jinja's Extension feature, which Hugging Face uses in its newest iteration of chat template handling:
https://github.com/huggingface/transformers/blob/v4.53.3/src/transformers/utils/chat_template_utils.py#L375

This works by registering a new tag with the Jinja template rendering environment (see the linked code above for the full implementation):

    from jinja2.ext import Extension

    class AssistantTracker(Extension):
        # This extension tracks the indices of assistant-generated tokens in the rendered chat
        tags = {"generation"}

Within the chat template, the tag is opened with {% generation %}
https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76
and closed with {% endgeneration %}
https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L82
around each assistant turn.

Then, when rendering the chat template with tokenize=True and return_assistant_tokens_mask=True, the indices at which the "generation" tag was active are returned:
https://github.com/huggingface/transformers/blob/v4.53.3/src/transformers/utils/chat_template_utils.py#L442
https://github.com/huggingface/transformers/blob/v4.53.3/src/transformers/processing_utils.py#L1629
and added as assistant_masks:
https://github.com/huggingface/transformers/blob/v4.53.3/src/transformers/processing_utils.py#L1678
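
For reference, a minimal usage sketch of the Hugging Face side (the mask comes back under the key "assistant_masks"; this assumes the model's chat template contains the generation tags, as SmolLM3-3B's does):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
    messages = [
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "Hello, how can I help?"},
    ]
    out = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    # out["assistant_masks"][i] == 1 exactly where token i belongs to an assistant turn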

We can do the same and remove the old special-token-based approach. However, this is not compatible with the current pbin file format, since we only store the token IDs and, as of now, no other fields (attention_mask, assistant_mask). We would need on-the-fly tokenization and packing for this Jinja-based assistant-turn tracking (see the sketch below).

Packing on the fly in TRL: https://github.com/huggingface/trl/blob/main/trl/data_utils.py#L495
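
A greedy on-the-fly packing sketch (simplified relative to the TRL implementation linked above, which additionally supports splitting overlong samples across sequences):

    def pack_on_the_fly(samples: list[list[int]], seq_len: int) -> list[list[int]]:
        """Greedily concatenates tokenized samples into sequences of at most seq_len."""
        packed, current = [], []
        for sample in samples:
            # Flush the current sequence if the next sample would overflow it.
            if current and len(current) + len(sample) > seq_len:
                packed.append(current)
                current = []
            current.extend(sample[:seq_len])  # truncate overlong samples
        if current:
            packed.append(current)
        return packed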
