4 changes: 3 additions & 1 deletion README.md
@@ -52,7 +52,7 @@ If you want to use `uv`:
uv init --bare --python 3.12
uv sync --python 3.12
source .venv/bin/activate
uv add torch numpy torchvision pillow datasets huggingface-hub transformers wandb
uv add -r requirements.txt
# Optional: the lmms-eval integration has to be installed from source, see section 'Evaluation with lmms-eval'
```

@@ -70,6 +70,7 @@ Dependencies:
- `datasets` for the training datasets
- `huggingface-hub` & `transformers` to load the pretrained backbones
- `wandb` for logging
- `einops` for data processing

## Training

@@ -109,6 +110,7 @@ Generation 5: This is a cat sitting on the ground, which is covered with a mat.

nanoVLM now supports evaluation using the comprehensive [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) toolkit:

###### Note that lmms-eval is currently unavailable on the M4 series!
```bash
# Install lmms-eval (has to be from source)
uv pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
35 changes: 31 additions & 4 deletions models/config.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field

import torch

@dataclass
class VLMConfig:
@@ -46,9 +46,9 @@ class VLMConfig:
"r4c1": "<row_4_col_1>", "r4c2": "<row_4_col_2>", "r4c3": "<row_4_col_3>", "r4c4": "<row_4_col_4>"})
vlm_load_backbone_weights: bool = True
vlm_checkpoint_path: str = 'checkpoints'
vlm_snapshot_path: str = 'snapshots'
hf_repo_name: str = 'nanoVLM'


@dataclass
class TrainConfig:
lr_mp: float = 0.00512
@@ -66,12 +66,39 @@ class TrainConfig:
max_images_per_knapsack: int = 18
max_sample_length: int = 1024
compile: bool = False
resume_from_vlm_checkpoint: bool = False # Whether to resume training from a checkpoint of the whole VLM instead of starting from scratch
compile_mode: str = 'default'
compile_dynamic: bool = False
resume_from_vlm_checkpoint: bool = False
train_dataset_path: str = 'HuggingFaceM4/the_cauldron'
train_dataset_name: tuple[str, ...] = ("all", )
wandb_entity: str = "HuggingFace" # Indicate the entity to log to in wandb
log_wandb: bool = True
use_lmms_eval: bool = True # Use lmms-eval for evaluation
lmms_eval_tasks: str = 'mmstar,mmmu,ocrbench,textvqa' # Pass additional tasks as one string, separated by commas without spaces (e.g. 'mmstar,mmmu,ocrbench')
lmms_eval_limit: int = 2000
lmms_eval_batch_size: int = 128
save_snapshot: bool = True
downcast_model: bool = False



MiniVLMConfig = VLMConfig()
MiniVLMConfig.lm_hidden_dim = 576
MiniVLMConfig.lm_inter_dim = 1536
MiniVLMConfig.lm_n_heads = 9
MiniVLMConfig.lm_n_kv_heads = 3
MiniVLMConfig.lm_n_blocks = 30
MiniVLMConfig.lm_attn_scaling = 1.0
MiniVLMConfig.lm_max_length = 1024
MiniVLMConfig.lm_model_type = 'HuggingFaceTB/SmolLM2-135M-Instruct'
MiniVLMConfig.lm_tokenizer = 'HuggingFaceTB/cosmo2-tokenizer'
MiniVLMConfig.vlm_snapshot_path = 'snapshots-mini'
MiniVLMConfig.hf_repo_name = 'nanoVLM-mini'


MiniTrainerConfig = TrainConfig()
MiniTrainerConfig.batch_size = 1
MiniTrainerConfig.gradient_accumulation_steps = 4
MiniTrainerConfig.train_dataset_name = ("ai2d", )
MiniTrainerConfig.downcast_model = True
MiniTrainerConfig.use_lmms_eval = False # disabled: lmms-eval currently fails to install
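Since `MiniVLMConfig` and `MiniTrainerConfig` are module-level instances rather than dataclass subclasses, a training script can import and use them directly. A minimal usage sketch, assuming a hypothetical `use_mini` flag that is not part of this diff:

```python
# Hypothetical wiring for a training entry point; `use_mini` and get_configs()
# are assumptions for illustration, not part of this diff.
from models.config import VLMConfig, TrainConfig, MiniVLMConfig, MiniTrainerConfig

def get_configs(use_mini: bool = False):
    if use_mini:
        # Module-level instances, already adjusted for SmolLM2-135M-Instruct.
        return MiniVLMConfig, MiniTrainerConfig
    return VLMConfig(), TrainConfig()
```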
1 change: 1 addition & 0 deletions models/language_model.py
@@ -655,3 +655,4 @@ def from_pretrained(cls, cfg):

print(f"Successfully loaded {cfg.lm_model_type} weights from safetensors. Model has {sum(p.numel() for p in model.parameters()):,} parameters.")
return model

6 changes: 6 additions & 0 deletions models/utils.py
@@ -49,3 +49,9 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')
logits = logits.masked_fill(indices_to_remove, filter_value)

return logits


def convert_dtypes(tensor: torch.Tensor, target_dtype: torch.dtype):
    # Cast `tensor` to `target_dtype` only when it is not already in that dtype.
    if tensor.dtype != target_dtype:
        tensor = tensor.to(target_dtype)
    return tensor
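A quick illustration of the helper's behaviour (assuming `models/utils.py` already imports `torch` at module level): the cast returns a new tensor only when the dtype differs, and is a no-op otherwise.

```python
# Illustrative check of convert_dtypes as added above.
import torch
from models.utils import convert_dtypes

x = torch.randn(2, 3, dtype=torch.bfloat16)
y = convert_dtypes(x, torch.float32)   # dtype differs -> cast copy is returned
assert y.dtype == torch.float32
z = convert_dtypes(y, torch.float32)   # dtype already matches -> same object, no copy
assert z is y
```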
5 changes: 4 additions & 1 deletion models/vision_language_model.py
@@ -5,7 +5,7 @@
from typing import Optional


from models.utils import top_k_top_p_filtering
from models.utils import top_k_top_p_filtering, convert_dtypes
from models.vision_transformer import ViT
from models.language_model import LanguageModel
from models.modality_projector import ModalityProjector
@@ -73,6 +73,9 @@ def forward(self, input_ids, images, attention_mask=None, targets=None):
logits = self.decoder.head(logits) # Apply LM head
# Loss is calculated over all tokens, but `targets` (labels) will have -100 for non-answer tokens.
# No need to slice logits based on image embedding size here, as the target mask handles it.
if os.environ.get("ModelDowncast"):
logits = convert_dtypes(logits, torch.float32) # if we downcast its more stable to have the logits in fp32

loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)

return logits, loss
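The `ModelDowncast` environment variable is presumably set alongside the new `downcast_model` training flag. A minimal sketch of how a trainer might connect the two; the function name and exact wiring are assumptions, not shown in this diff:

```python
# Minimal sketch, assuming the trainer sets the "ModelDowncast" env var when
# `downcast_model` is enabled. maybe_downcast() is hypothetical.
import os
import torch

def maybe_downcast(model, train_cfg):
    if getattr(train_cfg, "downcast_model", False):
        model = model.to(torch.bfloat16)   # keep weights/activations in bf16
        os.environ["ModelDowncast"] = "1"  # forward() then upcasts logits to fp32 for the loss
    return model
```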
9 changes: 9 additions & 0 deletions requirements.txt
@@ -0,0 +1,9 @@
torch
numpy
torchvision
pillow
datasets
huggingface-hub
transformers
wandb
einops