Conversation

@Vidit-Ostwal (Contributor) commented May 26, 2025

This PR adds support for training with QLoRA.

@sergiopaniego (Collaborator)

Nice! Feel free to ping me once it's ready for review!

@Vidit-Ostwal (Contributor, Author)

Hi @sergiopaniego, I think this one is ready for review.
Do let me know if you'd like any additional changes.
Thanks

@Vidit-Ostwal changed the title from "Add QLoRA support file" to "WIP Add QLoRA support file" on May 26, 2025
@sergiopaniego (Collaborator) left a comment

Thanks for the PR!

Can we see some training results 😄?

@aminejebbar

Isn't using batch.to(model.device) likely to be problematic, given that batch comes from a DataLoader with a custom collate_fn? That means batch is most likely a dictionary or a complex data structure containing multiple tensors rather than a simple tensor, and the .to(device) method only exists for PyTorch tensors, not for dictionaries. Wouldn't it be better to create a helper function to handle this and avoid errors? Please correct me if I'm wrong.
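
For reference, a minimal helper along the lines @aminejebbar suggests might look like the sketch below (not code from this PR; the name move_batch_to_device is hypothetical):

import torch

def move_batch_to_device(batch, device):
    """Recursively move all tensors in a (possibly nested) batch to `device`."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_batch_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_batch_to_device(item, device) for item in batch)
    return batch  # leave non-tensor values (strings, ints, ...) untouched

Note that the BatchFeature/BatchEncoding objects returned by transformers processors already implement .to(device), which is likely why the current code works; a helper like this only matters for plain dictionaries.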

@Vidit-Ostwal (Contributor, Author)

> Thanks for the PR!
> Can we see some training results 😄?

Yes, I'm working on this currently.
I don't have a GPU of the right size for this, so I'm trying to incorporate the accelerate library.

@Vidit-Ostwal (Contributor, Author)

> Isn't using batch.to(model.device) likely to be problematic, given that batch comes from a DataLoader with a custom collate_fn? […] Wouldn't it be better to create a helper function to handle this and avoid errors?

Hi @aminejebbar, I think you have answered this yourself: if you check, the DataLoader is imported from PyTorch itself.
The batches it yields do have the functionality to be moved to a different device (CPU or GPU).

@Vidit-Ostwal (Contributor, Author) commented Jun 1, 2025

Hey @sergiopaniego, great news:
I was able to train the model for 1 epoch with a batch size of 4.

Here is the model hub link: https://huggingface.co/ViditOstwal/SmolVLM-256M-Instruct-object-detection-epoch-1

Will add some inference results soon.

Training time was approximately 11 hours.

@Vidit-Ostwal (Contributor, Author)

Hi @sergiopaniego, I am trying to run the predict.py file, but I keep running into an assertion error.

This is how I am running the inference:

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained("ViditOstwal/SmolVLM-256M-Instruct-object-detection-epoch-1")

model = AutoModelForVision2Seq.from_pretrained(
    "ViditOstwal/SmolVLM-256M-Instruct-object-detection-epoch-1",
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=cfg.dtype,  # or replace cfg.dtype with torch.bfloat16 or torch.float16
    trust_remote_code=True  # important if it's a custom repo
)

from transformers import Idefics3Processor

model.eval()
test_dataloader = get_dataloader(processor=processor)
sample, sample_images = next(iter(test_dataloader))


model_device = next(model.parameters()).device
print(f"Model's primary device: {model_device}")

for key, value in sample.items():
    if isinstance(value, torch.Tensor):
        sample[key] = value.to(model_device)
        print(f"Moved {key} to {sample[key].device}")

generation = model.generate(**sample, max_new_tokens=100)

This is the error being raised:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_93/1572887551.py in <cell line: 0>()
----> 1 generation = model.generate(**sample, max_new_tokens=100)

/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    114     def decorate_context(*args, **kwargs):
    115         with ctx_factory():
--> 116             return func(*args, **kwargs)
    117 
    118     return decorate_context

/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, **kwargs)
   2463 
   2464             # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2465             result = self._sample(
   2466                 input_ids,
   2467                 logits_processor=prepared_logits_processor,

/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py in _sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3429 
   3430             if is_prefill:
-> 3431                 outputs = self(**model_inputs, return_dict=True)
   3432                 is_prefill = False
   3433             else:

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1748                 or _global_backward_pre_hooks or _global_backward_hooks
   1749                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750             return forward_call(*args, **kwargs)
   1751 
   1752         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/idefics3/modeling_idefics3.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, cache_position, return_dict, logits_to_keep)
   1108 
   1109         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1110         outputs = self.model(
   1111             input_ids=input_ids,
   1112             attention_mask=attention_mask,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1748                 or _global_backward_pre_hooks or _global_backward_hooks
   1749                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750             return forward_call(*args, **kwargs)
   1751 
   1752         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/idefics3/modeling_idefics3.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, use_cache, output_attentions, output_hidden_states, cache_position, return_dict)
    914 
    915             # Get sequence from the vision encoder
--> 916             image_hidden_states = self.vision_model(
    917                 pixel_values=pixel_values,
    918                 patch_attention_mask=patch_attention_mask,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1748                 or _global_backward_pre_hooks or _global_backward_hooks
   1749                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750             return forward_call(*args, **kwargs)
   1751 
   1752         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/idefics3/modeling_idefics3.py in forward(self, pixel_values, patch_attention_mask, output_attentions, output_hidden_states, return_dict)
    637             patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
    638 
--> 639         encoder_outputs = self.encoder(
    640             inputs_embeds=hidden_states,
    641             attention_mask=patch_attention_mask,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1748                 or _global_backward_pre_hooks or _global_backward_hooks
   1749                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750             return forward_call(*args, **kwargs)
   1751 
   1752         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/idefics3/modeling_idefics3.py in forward(self, inputs_embeds, attention_mask, output_attentions, output_hidden_states, return_dict)
    424                 )
    425             else:
--> 426                 layer_outputs = encoder_layer(
    427                     hidden_states,
    428                     attention_mask,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1748                 or _global_backward_pre_hooks or _global_backward_hooks
   1749                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750             return forward_call(*args, **kwargs)
   1751 
   1752         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/idefics3/modeling_idefics3.py in forward(self, hidden_states, attention_mask, output_attentions)
    336 
    337         hidden_states = self.layer_norm1(hidden_states)
--> 338         hidden_states, attn_weights = self.self_attn(
    339             hidden_states=hidden_states,
    340             attention_mask=attention_mask,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1748                 or _global_backward_pre_hooks or _global_backward_hooks
   1749                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750             return forward_call(*args, **kwargs)
   1751 
   1752         result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/idefics3/modeling_idefics3.py in forward(self, hidden_states, attention_mask, output_attentions)
    241         batch_size, seq_length, embed_dim = hidden_states.shape
    242 
--> 243         queries = self.q_proj(hidden_states)
    244         keys = self.k_proj(hidden_states)
    245         values = self.v_proj(hidden_states)

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1748                 or _global_backward_pre_hooks or _global_backward_hooks
   1749                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750             return forward_call(*args, **kwargs)
   1751 
   1752         result = None

/usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py in forward(self, x)
    478 
    479     def forward(self, x: torch.Tensor):
--> 480         fix_4bit_weight_quant_state_from_module(self)
    481 
    482         # weights are cast automatically as Int8Params, but the bias has to be cast manually

/usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py in fix_4bit_weight_quant_state_from_module(module)
    370     # the quant state got lost when the parameter got converted. This happens for example for fsdp
    371     # since we registered the module, we can recover the state here
--> 372     assert module.weight.shape[1] == 1
    373     if not isinstance(module.weight, Params4bit):
    374         module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)

AssertionError: 
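
One thing worth double-checking in the snippet above (an observation, not something confirmed in this thread): the BitsAndBytesConfig sets load_in_8bit=True together with bnb_4bit_* options, which only apply in 4-bit mode, while the assertion fires inside bitsandbytes' 4-bit code path. A minimal 4-bit (NF4) loading sketch, assuming the checkpoint was trained and saved with 4-bit quantization, could look like this:

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig

# Hypothetical 4-bit config; the bnb_4bit_* options have no effect when load_in_8bit=True.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_id = "ViditOstwal/SmolVLM-256M-Instruct-object-detection-epoch-1"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
model.eval()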

@Vidit-Ostwal (Contributor, Author)

I also get a warning saying "Some weights of the model checkpoint at ViditOstwal/SmolVLM-256M-Instruct-object-detection-epoch-1 were not used when initializing Idefics3ForConditionalGeneration", even though I have used the same configuration.

@Vidit-Ostwal (Contributor, Author)

I also had one doubt:
by mistake I didn't pass the right model variable, so it fine-tuned all of the parameters. The training time for one epoch over the entire dataset was ~11 hours.

Now that I have passed the QLoRA-configured model, the training time is still approximately the same.
This seems a bit non-intuitive to me.
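
A quick way to confirm that only the LoRA adapters are being trained is to look at the trainable-parameter count. A minimal sketch, assuming model has been wrapped with peft's get_peft_model as in the QLoRA setup:

# For a peft-wrapped model this prints something like
# "trainable params: ... || all params: ... || trainable%: ..."
model.print_trainable_parameters()

# Equivalent manual check that works for any torch.nn.Module:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

Similar epoch times are not necessarily a red flag: with QLoRA the forward and backward passes still run through the whole (frozen, quantized) base model, so QLoRA mainly reduces memory use rather than per-step wall-clock time.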

@Vidit-Ostwal (Contributor, Author)

Hi @sergiopaniego, I was trying to run inference with the QLoRA model.
This is what I get after printing the decoded output in predict.py:

['user\n\n\n\n\n\ndetect\n\nassistant\nThe plate in the image contains the license plate number, which is 29A51796. The plate is yellow with black borders, and it has a shiny surface. The plate has a small emblem on the front, which is the Toyota logo. The background of the plate is black, and it has a yellow stripe running along the top.',
 'user\n\n\n\n\n\ndetect\n\nassistant\nThe plate in the image contains the license plate number, which is 51F-74776. The plate is yellow and black, and it has a shiny surface. The plate has a grill with a grill plate, which is yellow and black, and it has a plate with a plate plate, which is yellow and black.',
 'user\n\n\n\n\n\ndetect\n\nassistant\nThe plate has the number "4129" on it.',
 'user\n\n\n\n\n\ndetect\n\nassistant\n(plate)\n\n(car)\n\n(license plate)\n\n(vehicle)\n\n(building)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)\n\n(sign)']

This is the test_collate function I am using; I have modified it a bit to make it compatible:

def test_collate_function(batch_of_samples, processor, dtype):
    images = []
    prompts = []

    for sample in batch_of_samples:
        images.append([sample["image"]])  # still one image per sample

        # prompt must include <image> to match the number of images
        prompts.append(
            f"{processor.tokenizer.bos_token}user\n<image>\ndetect\n{processor.tokenizer.eos_token}\n{processor.tokenizer.bos_token}assistant"
        )

    batch = processor(images=images, text=prompts, return_tensors="pt", padding=True)
    batch["pixel_values"] = batch["pixel_values"].to(device=cfg.device, dtype=torch.bfloat16)
    return batch, images
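
For context, one hypothetical way to wire this collate function into a test DataLoader (test_dataset here stands in for whatever the repo actually provides; this is a sketch, not the repo's get_dataloader):

from functools import partial

import torch
from torch.utils.data import DataLoader

# `test_dataset` and `processor` are assumed to come from the repo's setup code.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=partial(test_collate_function, processor=processor, dtype=torch.bfloat16),
)

batch, images = next(iter(test_dataloader))  # the collate function returns (batch, images)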

@sergiopaniego (Collaborator)

Thanks for the contribution again 😄
Maybe you can bring in the latest changes and adapt them to this configuration.
It should support both models (SmolVLM + Gemma 3) if you've trained with SmolVLM.

@Vidit-Ostwal (Contributor, Author) commented Jun 25, 2025

Hi @sergiopaniego,
I have added QLoRA training to the main.py file itself.

An additional --peft_with_qlora flag can be passed on the CLI to train with QLoRA.
Let me know if this sounds good to you.
Otherwise I can make the changes in the train_qlora.py file instead.
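
For illustration, a minimal sketch of how such a flag could be wired up. Only the flag name --peft_with_qlora comes from the comment above; the base model id, hyperparameters, and target modules are assumptions, not this PR's actual code:

import argparse

import torch
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

parser = argparse.ArgumentParser()
parser.add_argument("--peft_with_qlora", action="store_true", help="Train with QLoRA instead of full fine-tuning")
args = parser.parse_args()

model_id = "HuggingFaceTB/SmolVLM-256M-Instruct"  # assumed base model

if args.peft_with_qlora:
    # Load the base model in 4-bit NF4 and attach LoRA adapters to the attention projections.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForVision2Seq.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
    model = prepare_model_for_kbit_training(model)
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    model = get_peft_model(model, lora_config)
else:
    model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")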

@Vidit-Ostwal (Contributor, Author) left a comment

logger.info("Stage 2: Fine-tuning embed_tokens + attn")
run_training_phase(model, processor, cfg, train_dataloader, train_keys=["embed_tokens", "attn"], phase_name="embed_attn")
    else:
logger.info("Single-stage: Fine-tuning attn only")
run_training_phase(model, processor, cfg, train_dataloader, train_keys=["attn"], phase_name="attn_only")

I just have a small doubt:
I think the train_keys here need to be changed to q_proj, k_proj, v_proj, o_proj if we are planning to train with QLoRA, right?
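
One quick way to check which parameter names the QLoRA setup actually trains (and therefore which train_keys would match them) is to list the trainable parameters of the peft-wrapped model; a sketch, assuming model is that wrapped model:

# peft injects lora_A / lora_B modules inside the targeted projections, so the
# trainable parameter names typically contain "q_proj.lora_A", "k_proj.lora_B", etc.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)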
