Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions test_vlm_text_only_issue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Test for issue #3957 - VLM KeyError fix"""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be deleted before merging


from unittest.mock import MagicMock
import torch
from trl.trainer.sft_trainer import DataCollatorForVisionLanguageModeling


def test_collator():
    """Regression check for issue #3957.

    The vision-language collator must forward an ``images`` kwarg to the
    processor only when the examples actually contain images; a text-only
    batch previously raised ``KeyError: 'images'``.
    """
    processor = MagicMock()
    processor.apply_chat_template = MagicMock(return_value=["test"])
    processor.return_value = {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])}

    collator = DataCollatorForVisionLanguageModeling(processor=processor)

    # Batch WITH images: the processor call should receive them.
    batch_with_images = [{"images": ["img"], "messages": [{"role": "user", "content": "test"}]}]
    collator(batch_with_images)
    assert "images" in processor.call_args.kwargs

    # Batch WITHOUT images: before the fix this path raised KeyError; now
    # the `images` kwarg must simply be omitted from the processor call.
    processor.reset_mock()
    processor.return_value = {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])}
    batch_text_only = [{"messages": [{"role": "user", "content": "test"}]}]
    collator(batch_text_only)
    assert "images" not in processor.call_args.kwargs

    print("Tests passed")


# Allow running this throwaway check directly as a script (no pytest needed);
# per the PR comment, the file is slated for deletion before merge.
if __name__ == "__main__":
    test_collator()
68 changes: 45 additions & 23 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,17 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.")

def _collate_language_modeling(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
images = [example["images"] for example in examples]
# Check if examples contain images - some VLMs can be used for text-only tasks
has_images = "images" in examples[0]
if has_images:
images = [example["images"] for example in examples]
else:
images = None

if "messages" in examples[0]: # conversational case
for example in examples:
prepare_multimodal_messages(example["messages"], len(example["images"]))
num_images = len(example["images"]) if has_images else 0
prepare_multimodal_messages(example["messages"], num_images)
messages = [example["messages"] for example in examples]
texts = self.processor.apply_chat_template(messages)
elif self.dataset_text_field in examples[0]: # standard case
Expand All @@ -399,17 +405,21 @@ def _collate_language_modeling(self, examples: list[Union[list[int], Any, dict[s
"data."
)

output = self.processor(
images=images,
text=texts,
padding=True,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
truncation=self.max_length is not None,
max_length=self.max_length,
return_tensors=self.return_tensors,
add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
)
# For text-only data with VLM models, don't pass images to the processor
processor_kwargs = {
"text": texts,
"padding": True,
"padding_side": "right",
"pad_to_multiple_of": self.pad_to_multiple_of,
"truncation": self.max_length is not None,
"max_length": self.max_length,
"return_tensors": self.return_tensors,
"add_special_tokens": False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
}
if has_images:
processor_kwargs["images"] = images

output = self.processor(**processor_kwargs)
labels = output["input_ids"].clone()
labels[output["attention_mask"] == 0] = -100
# We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in
Expand All @@ -424,23 +434,35 @@ def _collate_prompt_completion(self, examples: list[Union[list[int], Any, dict[s
"Padding to a multiple of a value is not yet implemented for vision-language modeling and "
"prompt-completion data yet."
)
images = [example["images"] for example in examples]

# Check if examples contain images - some VLMs can be used for text-only tasks
has_images = "images" in examples[0]
if has_images:
images = [example["images"] for example in examples]
else:
images = None

if is_conversational(examples[0]): # conversational case
for example in examples:
prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"]))
num_images = len(example["images"]) if has_images else 0
prepare_multimodal_messages(example["prompt"] + example["completion"], num_images)
examples = [apply_chat_template(example, self.processor) for example in examples]

prompts = [example["prompt"] for example in examples]
completions = [example["completion"] for example in examples]

processed_prompts = self.processor(
images=images,
text=prompts,
padding=True,
padding_side="left",
return_tensors=self.return_tensors,
add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
)
# For text-only data with VLM models, don't pass images to the processor
processor_kwargs = {
"text": prompts,
"padding": True,
"padding_side": "left",
"return_tensors": self.return_tensors,
"add_special_tokens": False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
}
if has_images:
processor_kwargs["images"] = images

processed_prompts = self.processor(**processor_kwargs)
processed_completions = self.processor(
text=completions,
padding=True,
Expand Down