30 changes: 30 additions & 0 deletions tests/test_sft_trainer.py
@@ -1374,6 +1374,36 @@ def test_train_vlm_gemma_3n(self):
continue
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")

@require_vision
def test_train_vlm_text_only_data(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

# Initialize the trainer
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
args=training_args,
train_dataset=dataset,
)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n.startswith("model.visual"):
self.assertTrue(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated")
else:
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")

@require_peft
def test_prompt_tuning(self):
"""Test that SFT works with Prompt Tuning."""
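Note on the expected behavior in the new test: with text-only data the collator never produces pixel_values, so the vision tower (parameters under model.visual) is never run and receives no gradient, while the language-model parameters do update. A minimal sketch of how one could verify that directly, assuming a transformers-style model whose forward returns an object with a .loss (this helper is illustrative, not part of the PR):

def params_without_grad(model, batch):
    # Run one forward/backward pass and list trainable parameters that
    # received no gradient; for a VLM on a text-only batch, the vision
    # tower's parameters are expected to appear here.
    model.zero_grad()
    loss = model(**batch).loss
    loss.backward()
    return [n for n, p in model.named_parameters() if p.requires_grad and p.grad is None]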
41 changes: 35 additions & 6 deletions trl/trainer/sft_trainer.py
@@ -758,7 +758,14 @@ def __init__(
else:
self.completion_only_loss = args.completion_only_loss

if data_collator is None and not self._is_vlm:
self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
if self._is_vision_dataset and not self._is_vlm:
raise ValueError(
"The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
"model does not seem to be a vision-language model. Please check your model and dataset."
)

if data_collator is None and not self._is_vision_dataset:
# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
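A standalone sketch of the detection rule introduced above, with its three outcomes (simplified; the real check runs on a single dataset_sample in __init__):

def is_vision_dataset(dataset_sample: dict) -> bool:
    # A dataset counts as vision-related when the sample carries an
    # "image" or "images" column, regardless of the model type.
    return "image" in dataset_sample or "images" in dataset_sample

assert is_vision_dataset({"images": [], "messages": []})  # vision collator path
assert not is_vision_dataset({"messages": []})            # text collator path, even for a VLM
# vision dataset + non-VLM model: __init__ raises ValueError early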
@@ -777,7 +784,7 @@ def __init__(
return_position_ids=use_flash_attention,
pad_to_multiple_of=args.pad_to_multiple_of,
)
elif data_collator is None and self._is_vlm:
elif data_collator is None and self._is_vision_dataset:
data_collator = DataCollatorForVisionLanguageModeling(
processor=processing_class,
max_length=args.max_length,
@@ -805,7 +812,9 @@ def __init__(
# Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a vision
# dataset, where preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead.
skip_prepare_dataset = (
args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vlm
args.dataset_kwargs is not None
and args.dataset_kwargs.get("skip_prepare_dataset", False)
or self._is_vision_dataset
)
if not skip_prepare_dataset:
if self.completion_only_loss and formatting_func:
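Since the skip condition above mixes and with or, its grouping is worth spelling out: and binds tighter than or in Python, so preparation is skipped either on explicit request or unconditionally for vision datasets. A self-contained restatement:

def should_skip_prepare(dataset_kwargs, is_vision_dataset):
    # Explicit opt-out via dataset_kwargs, OR always for vision datasets
    # (their preprocessing happens on the fly in the collator).
    explicit = dataset_kwargs is not None and dataset_kwargs.get("skip_prepare_dataset", False)
    return explicit or is_vision_dataset

assert should_skip_prepare(None, True)                             # vision dataset: always skipped
assert should_skip_prepare({"skip_prepare_dataset": True}, False)  # explicit opt-out
assert not should_skip_prepare(None, False)                        # plain text: prepared as before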
@@ -959,22 +968,36 @@ def add_eos(example, eos_token):
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

def tokenize(example, processing_class, dataset_text_field, assistant_only_loss):
def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_loss):
if "prompt" in example: # prompt-completion case
output = {}
if is_conversational(example):
if self._is_vlm:
prepare_multimodal_messages(example["prompt"], num_images=0)
prepare_multimodal_messages(example["completion"], num_images=0)
prompt_ids = processing_class.apply_chat_template(
example["prompt"],
tokenize=True,
tools=example.get("tools"),
**example.get("chat_template_kwargs", {}),
)
# Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
# even for single examples, while for LLMs it returns lists of ints.
prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
prompt_completion_processed = processing_class.apply_chat_template(
example["prompt"] + example["completion"],
return_dict=True,
tokenize=True,
return_assistant_tokens_mask=assistant_only_loss,
tools=example.get("tools"),
**example.get("chat_template_kwargs", {}),
)
# Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
# even for single examples, while for LLMs it returns lists of ints.
prompt_completion_processed = {
k: v[0] if isinstance(v[0], list) else v
for k, v in prompt_completion_processed.items()
}
prompt_completion_ids = prompt_completion_processed["input_ids"]
if "assistant_masks" in prompt_completion_processed:
output["assistant_masks"] = prompt_completion_processed["assistant_masks"]
@@ -999,13 +1022,19 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)

else: # language modeling case
if is_conversational(example):
if self._is_vlm:
prepare_multimodal_messages(example["messages"], num_images=0)
processed = processing_class.apply_chat_template(
example["messages"],
return_dict=True,
tokenize=True,
return_assistant_tokens_mask=assistant_only_loss,
tools=example.get("tools"),
**example.get("chat_template_kwargs", {}),
)
# Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
# even for single examples, while for LLMs it returns lists of ints.
processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()}
if "assistant_masks" in processed and 1 not in processed["assistant_masks"]:
raise RuntimeError(
"You're using `assistant_only_loss=True`, but at least one example has no "
Expand All @@ -1020,7 +1049,7 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)
return output

dataset = dataset.map(
tokenize,
tokenize_fn,
fn_kwargs={
"processing_class": processing_class,
"dataset_text_field": args.dataset_text_field,
@@ -1064,7 +1093,7 @@ def _set_signature_columns_if_needed(self):
# and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
# dataset. So we need to override the default signature columns to include "completion_mask" as well.
if self._signature_columns is None:
if self._is_vlm:
if self._is_vision_dataset:
self._signature_columns = ["messages", "prompt", "completion", "images"]
else:
self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]