diff --git a/.gitignore b/.gitignore
index 7c728849..01041ab7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,4 +28,6 @@ results/
 *.slurm
 *.arrow
 /shards/*
-*.png
\ No newline at end of file
+*.png
+
+env/
\ No newline at end of file
diff --git a/models/config.py b/models/config.py
index 08f74dbb..2db5dce7 100644
--- a/models/config.py
+++ b/models/config.py
@@ -74,4 +74,30 @@ class TrainConfig:
     use_lmms_eval: bool = True # Use lmms-eval for evaluation
     lmms_eval_tasks: str = 'mmstar,mmmu,ocrbench,textvqa' # Pass additional task as one string, seperated by commas without spaces (e.g. 'mmstar,mmmu,ocrbench')
     lmms_eval_limit: int = 2000
-    lmms_eval_batch_size: int = 128
\ No newline at end of file
+    lmms_eval_batch_size: int = 128
+
+    # LoRA Configuration
+    use_lora: bool = True
+    lora_rank: int = 16
+    lora_alpha: int = 32
+    lora_dropout: float = 0.1
+    lora_target_modules: list[str] = field(default_factory=lambda: [
+        # Vision Transformer
+        "vision_encoder.blocks.*.attn.qkv_proj",
+        "vision_encoder.blocks.*.attn.out_proj",
+        "vision_encoder.blocks.*.mlp.fc1",
+        "vision_encoder.blocks.*.mlp.fc2",
+
+        # Language Model
+        "decoder.blocks.*.attn.q_proj",
+        "decoder.blocks.*.attn.k_proj",
+        "decoder.blocks.*.attn.v_proj",
+        "decoder.blocks.*.attn.out_proj",
+        "decoder.blocks.*.mlp.gate_proj",
+        "decoder.blocks.*.mlp.up_proj",
+        "decoder.blocks.*.mlp.down_proj",
+        "decoder.head",
+
+        # Modality Projector
+        "MP.proj",
+    ])
\ No newline at end of file
diff --git a/models/vision_language_model.py b/models/vision_language_model.py
index f22de5cc..262500d7 100644
--- a/models/vision_language_model.py
+++ b/models/vision_language_model.py
@@ -33,6 +33,39 @@ def __init__(self, cfg: VLMConfig, load_backbone=True):
         self.load_backbone = load_backbone
         self.tokenizer = get_tokenizer(cfg.lm_tokenizer, cfg.vlm_extra_tokens, cfg.lm_chat_template)
+        # Add config attribute for PEFT compatibility
+        self.config = self._create_hf_compatible_config()
+
+    def _create_hf_compatible_config(self):
+        """Create a minimal HuggingFace-compatible config for PEFT"""
+
+        class HFCompatibleConfig:
+            """A config class that behaves like both an object and a dictionary for PEFT compatibility"""
+            def __init__(self, cfg):
+                self.model_type = "vision_language_model"  # Custom model type
+                self.vocab_size = cfg.lm_vocab_size
+                self.hidden_size = cfg.lm_hidden_dim
+                self.num_hidden_layers = cfg.lm_n_blocks
+                self.num_attention_heads = cfg.lm_n_heads
+                self.intermediate_size = cfg.lm_inter_dim
+                self.max_position_embeddings = cfg.lm_max_position_embeddings
+                self.rms_norm_eps = cfg.lm_rms_eps
+                self.tie_word_embeddings = cfg.lm_tie_weights
+
+            def get(self, key, default=None):
+                """Dictionary-like get method for PEFT compatibility"""
+                return getattr(self, key, default)
+
+            def __getitem__(self, key):
+                """Dictionary-like access for PEFT compatibility"""
+                return getattr(self, key)
+
+            def __contains__(self, key):
+                """Dictionary-like 'in' operator for PEFT compatibility"""
+                return hasattr(self, key)
+
+        return HFCompatibleConfig(self.cfg)
+
 
     def _replace_img_tokens_with_embd(self, input_ids, token_embd, image_embd):
         """
         Replace every image-token placeholder in `input_ids` with the corresponding slice
@@ -48,7 +81,15 @@ def _replace_img_tokens_with_embd(self, input_ids, token_embd, image_embd):
 
         return updated_token_embd
 
-    def forward(self, input_ids, images, attention_mask=None, targets=None):
+
+    def forward(self, input_ids, images=None, attention_mask=None, labels=None, **kwargs):
+        # Handle different argument names - PEFT might pass 'labels' instead of 'targets'
+        targets = labels if labels is not None else kwargs.get('targets', None)
+
+        # Handle cases where images might be passed through kwargs
+        if images is None:
+            images = kwargs.get('images', [])
+
         if isinstance(images, list):
             if not images: # Handle cases with no images
                 images = torch.empty(0, self.cfg.vit_channels, self.cfg.vit_image_size, self.cfg.vit_image_size, device=input_ids.device)
@@ -76,6 +117,17 @@ def forward(self, input_ids, images, attention_mask=None, targets=None):
             loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)

        return logits, loss
+
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        """
+        Prepare inputs for generation. Required by PEFT for CAUSAL_LM task type.
+        This is a minimal implementation that just returns the basic inputs.
+        """
+
+        return {
+            "input_ids": input_ids,
+            **kwargs
+        }

     @torch.inference_mode()
     def generate(self, input_ids, images, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):
diff --git a/train.py b/train.py
index 17ce4018..5db700fc 100644
--- a/train.py
+++ b/train.py
@@ -13,6 +13,7 @@
 from torch.utils.data import DataLoader, DistributedSampler
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
+from peft import LoraConfig, get_peft_model, TaskType

 torch.manual_seed(0)
 if torch.cuda.is_available():
@@ -210,6 +211,24 @@ def train(train_cfg, vlm_cfg):
     else:
         model = VisionLanguageModel(vlm_cfg, load_backbone=vlm_cfg.vlm_load_backbone_weights)

+    # Apply LoRA if enabled
+    if train_cfg.use_lora:
+        lora_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,  # Now works with the added prepare_inputs_for_generation method
+            r=train_cfg.lora_rank,
+            lora_alpha=train_cfg.lora_alpha,
+            lora_dropout=train_cfg.lora_dropout,
+            target_modules=train_cfg.lora_target_modules,
+            bias="none",
+        )
+
+        model = get_peft_model(model, lora_config)
+
+        if is_master():
+            print("\n=== LoRA Configuration ===")
+            model.print_trainable_parameters()
+            print("===========================\n")
+
     if is_master():
         print(f"nanoVLM initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
         print(f"Training summary{' (global)' if is_dist() else ''}: {len(train_loader.dataset)} samples, {int(len(train_loader)*get_world_size())} batches/epoch, batch size {int(train_cfg.batch_size*get_world_size()*train_cfg.gradient_accumulation_steps)}{', training on ' + str(get_world_size()) + ' GPUs' if is_dist() else ''}")
@@ -299,7 +318,7 @@ def train(train_cfg, vlm_cfg):
                 )
                 with autocast_context:
                     with context:
-                        _, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)
+                        _, loss = model(input_ids=input_ids, images=images, attention_mask=attention_mask, labels=labels)

                 if train_cfg.gradient_accumulation_steps > 1:
                     loss = loss / train_cfg.gradient_accumulation_steps
@@ -357,7 +376,7 @@ def train(train_cfg, vlm_cfg):
                     attention_mask = batch["attention_mask"].to(device)

                     with autocast_context:
-                        _, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)
+                        _, loss = model(input_ids=input_ids, images=images, attention_mask=attention_mask, labels=labels)

                     total_val_loss += loss.item()
             avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
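
A minimal post-training sketch, not part of the diff above: it assumes `model` is the `PeftModel` returned by `get_peft_model`, uses PEFT's standard `save_pretrained` and `merge_and_unload` adapter APIs, and the checkpoint paths are hypothetical.

```python
# Hypothetical post-training step (assumes `model` is the PeftModel created in
# train.py and `is_master()` guards rank-0 I/O as elsewhere in the script).
if is_master():
    # Persist only the LoRA adapter (adapter_config.json + adapter weights),
    # which is much smaller than a full VisionLanguageModel checkpoint.
    model.save_pretrained("checkpoints/lora_adapter")  # hypothetical path

    # Optionally fold the LoRA deltas into the base weights so inference can
    # run on the plain VisionLanguageModel without the PEFT wrapper.
    merged_model = model.merge_and_unload()
    torch.save(merged_model.state_dict(), "checkpoints/merged_model.pt")  # hypothetical path
```

Keeping the adapter and the merged checkpoint separate lets the adapter be re-applied to an updated base model later, while the merged state dict is convenient for deployment.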