Instruction Tuning Improvements #380
base: main
Conversation
This is a draft.
What has been done so far:
- Refactored the collator design to be composable: multiple collate_fns can now be called sequentially. This is helpful for instruction tuning, for example, where we first shift the targets for the autoregressive objective and then mask certain tokens (such as input from the user role) so they are disregarded in the loss. In this case, two collate_fns are called one after the other (see the first sketch after this list).
- We can now specify which components to load into the app_state from a DCP checkpoint. This is important for continued pretraining or finetuning, where we don't want to reuse the previous optimizer and LR scheduler states (see the second sketch after this list).
- Added an iterative mem-map dataset implementation. Instead of packing the samples, we use the index stored in the pbin file to iterate over each sample individually (see the third sketch after this list).
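A minimal sketch of the composable-collator idea (first bullet). For brevity it operates on an already-stacked batch dict rather than a list of samples, and all names (`SequentialCollator`, `shift_targets`, `mask_non_assistant_tokens`, the `assistant_mask` field) are hypothetical, not the PR's actual API:

```python
from dataclasses import dataclass
from typing import Callable

import torch


def shift_targets(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Autoregressive objective: predict token t+1 from token t.
    batch["targets"] = batch["input_ids"][:, 1:].clone()
    batch["input_ids"] = batch["input_ids"][:, :-1]
    batch["assistant_mask"] = batch["assistant_mask"][:, 1:]
    return batch


def mask_non_assistant_tokens(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Exclude user-role (and other non-assistant) tokens from the loss.
    batch["targets"] = batch["targets"].masked_fill(batch["assistant_mask"] == 0, -100)
    return batch


@dataclass
class SequentialCollator:
    # Calls the wrapped collate_fns one after the other on the same batch.
    collate_fns: list[Callable[[dict], dict]]

    def __call__(self, batch: dict) -> dict:
        for fn in self.collate_fns:
            batch = fn(batch)
        return batch


# Usage: collate = SequentialCollator([shift_targets, mask_non_assistant_tokens])
```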
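For the second bullet, a sketch of the underlying idea using torch.distributed.checkpoint directly: only the model weights are restored, while optimizer and LR scheduler are built from scratch. It assumes the checkpoint was saved with a top-level "model" entry; the model and `checkpoint_path` are placeholders and this is not the PR's actual app_state mechanism:

```python
import torch
import torch.distributed.checkpoint as dcp
from torch import nn

checkpoint_path = "/path/to/dcp_checkpoint"  # placeholder
model = nn.Linear(8, 8)  # placeholder for the real (possibly sharded) model

# dcp.load only restores the keys present in the passed state_dict, so the
# checkpoint's optimizer / lr_scheduler entries are simply ignored here.
state_dict = {"model": model.state_dict()}
dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_path)
model.load_state_dict(state_dict["model"])

# Optimizer and LR scheduler are constructed anew instead of being restored.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
```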
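For the third bullet, a sketch of an iterative (non-packing) memory-mapped dataset. The (offset, length) index format and the pickled index file are assumptions about the pbin layout, used only for illustration:

```python
import pickle

import numpy as np
from torch.utils.data import Dataset


class IterativeMemMapDataset(Dataset):
    """Returns one original (unpacked) sample per __getitem__ call."""

    def __init__(self, token_file: str, index_file: str):
        self.tokens = np.memmap(token_file, dtype=np.uint32, mode="r")
        with open(index_file, "rb") as f:
            # list of (offset, length) pairs, one per sample
            self.index = pickle.load(f)

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(self, i: int) -> np.ndarray:
        offset, length = self.index[i]
        return np.asarray(self.tokens[offset : offset + length])
```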
Todos:
- Add tests for the new functionality
- Fix failing tests.
- Verify correctness of the code.
- Sometimes we skip samples by masking out all of their tokens in the loss (setting the target token IDs to -100). When the whole batch consists only of such "skipped" samples, the loss is NaN; with batch size 1 this happens quite frequently. We need a way to not run backprop on such batches (see the sketch after this list).
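One possible guard for the NaN issue, as a hypothetical sketch (variable names are placeholders): check whether any target in the batch is a valid (non-ignored) token before running the backward pass.

```python
import torch

IGNORE_INDEX = -100


def train_step(model, optimizer, loss_fn, batch) -> torch.Tensor | None:
    targets = batch["targets"]
    if (targets != IGNORE_INDEX).sum() == 0:
        # Every token in the batch is masked out; cross entropy over zero
        # valid tokens would yield NaN, so skip backprop entirely.
        return None
    logits = model(batch["input_ids"])
    # loss_fn is e.g. nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    loss = loss_fn(logits.flatten(0, 1), targets.flatten())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()
```

Note that in DDP/FSDP training all ranks would have to agree on whether a step is skipped, otherwise the collectives desynchronize; the sketch ignores that.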
The current plan is to make the instruction-tuning data preparation independent of special tokens, since #383 persists. This works by registering a new tag in the Jinja template render environment (see the linked code above for the full implementation):

```python
from jinja2.ext import Extension


class AssistantTracker(Extension):
    # This extension is used to track the indices of assistant-generated tokens in the rendered chat
    tags = {"generation"}
```

Within the chat template, this tag is opened and closed around the assistant turns. Then, during rendering of the chat template with tokenize=True and return_assistant_mask=True, the indices during which the "generation" tag was active are returned. We can do the same and remove the old special-token-based approach. However, this is not compatible with the current pbin file, as we currently only store the token IDs and no other fields (attention_mask, assistant_mask). We would need tokenization and packing on the fly for this Jinja-based assistant-turn tracking. Packing on the fly in TRL: https://github.com/huggingface/trl/blob/main/trl/data_utils.py#L495
What does this PR do?
This PR ..

General Changes

Breaking Changes

Checklist before submitting final PR
- Run all tests (`python tests/tests.py`)
- Update `CHANGELOG_DEV.md`