diff --git a/README.md b/README.md
index 2e5a193e..34a768ce 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,7 @@ If you want to use `uv`:
 uv init --bare --python 3.12
 uv sync --python 3.12
 source .venv/bin/activate
-uv add torch numpy torchvision pillow datasets huggingface-hub transformers wandb
+uv add -r requirements.txt
 # Optional: for lmms-eval integration you have to install it from source, see section 'Evaluation with lmms-eval'
 ```
 
@@ -70,6 +70,7 @@ Dependencies:
 - `datasets` for the training datasets
 - `huggingface-hub` & `transformers` to load the pretrained backbones
 - `wandb` for logging
+- `einops` for data processing
 
 ## Training
 
@@ -109,6 +110,7 @@ Generation 5: This is a cat sitting on the ground, which is covered with a mat.
 
 nanoVLM now supports evaluation using the comprehensive [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) toolkit:
 
+###### Note: lmms-eval is currently unavailable on the Apple M4 series.
 ```bash
 # Install lmms-eval (has to be from source)
 uv pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
diff --git a/models/config.py b/models/config.py
index 08f74dbb..cc3c8f13 100644
--- a/models/config.py
+++ b/models/config.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-
+import torch
 
 @dataclass
 class VLMConfig:
@@ -46,9 +46,9 @@ class VLMConfig:
                                                            "r4c1": "", "r4c2": "", "r4c3": "", "r4c4": ""})
     vlm_load_backbone_weights: bool = True
     vlm_checkpoint_path: str = 'checkpoints'
+    vlm_snapshot_path: str = 'snapshots'
     hf_repo_name: str = 'nanoVLM'
 
-
 @dataclass
 class TrainConfig:
     lr_mp: float = 0.00512
@@ -66,7 +66,9 @@ class TrainConfig:
     max_images_per_knapsack: int = 18
     max_sample_length: int = 1024
     compile: bool = False
-    resume_from_vlm_checkpoint: bool = False # Indicate if the training should be resumed from a checkpoint of the whole VLM or you want to start from scratch
+    compile_mode: str = 'default'
+    compile_dynamic: bool = False
+    resume_from_vlm_checkpoint: bool = False
     train_dataset_path: str = 'HuggingFaceM4/the_cauldron'
     train_dataset_name: tuple[str, ...] = ("all", )
     wandb_entity: str = "HuggingFace" # Indicate the entity to log to in wandb
@@ -74,4 +76,29 @@ class TrainConfig:
     use_lmms_eval: bool = True # Use lmms-eval for evaluation
     lmms_eval_tasks: str = 'mmstar,mmmu,ocrbench,textvqa' # Pass additional task as one string, seperated by commas without spaces (e.g. 'mmstar,mmmu,ocrbench')
     lmms_eval_limit: int = 2000
-    lmms_eval_batch_size: int = 128
\ No newline at end of file
+    lmms_eval_batch_size: int = 128
+    save_snapshot: bool = True
+    downcast_model: bool = False
+
+
+
+MiniVLMConfig = VLMConfig()
+MiniVLMConfig.lm_hidden_dim = 576
+MiniVLMConfig.lm_inter_dim = 1536
+MiniVLMConfig.lm_n_heads = 9
+MiniVLMConfig.lm_n_kv_heads = 3
+MiniVLMConfig.lm_n_blocks = 30
+MiniVLMConfig.lm_attn_scaling = 1.0
+MiniVLMConfig.lm_max_length = 1024
+MiniVLMConfig.lm_model_type = 'HuggingFaceTB/SmolLM2-135M-Instruct'
+MiniVLMConfig.lm_tokenizer = 'HuggingFaceTB/cosmo2-tokenizer'
+MiniVLMConfig.vlm_snapshot_path = 'snapshots-mini'
+MiniVLMConfig.hf_repo_name = 'nanoVLM-mini'
+
+
+MiniTrainerConfig = TrainConfig()
+MiniTrainerConfig.batch_size = 1
+MiniTrainerConfig.gradient_accumulation_steps = 4
+MiniTrainerConfig.train_dataset_name = ("ai2d", )
+MiniTrainerConfig.downcast_model = True
+MiniTrainerConfig.use_lmms_eval = False # disabled: lmms-eval currently fails to install, see note in README
diff --git a/models/language_model.py b/models/language_model.py
index 61b58297..71a1214b 100644
--- a/models/language_model.py
+++ b/models/language_model.py
@@ -655,3 +655,4 @@ def from_pretrained(cls, cfg):
 
         print(f"Successfully loaded {cfg.lm_model_type} weights from safetensors. Model has {sum(p.numel() for p in model.parameters()):,} parameters.")
         return model
+
\ No newline at end of file
diff --git a/models/utils.py b/models/utils.py
index c03e90d7..3712aaa7 100644
--- a/models/utils.py
+++ b/models/utils.py
@@ -49,3 +49,9 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')
 
     logits = logits.masked_fill(indices_to_remove, filter_value)
     return logits
+
+
+def convert_dtypes(tensor: torch.Tensor, target_dtype: torch.dtype):
+    if tensor.dtype != target_dtype:
+        tensor = tensor.to(target_dtype)
+    return tensor
\ No newline at end of file
diff --git a/models/vision_language_model.py b/models/vision_language_model.py
index f22de5cc..cc22b87c 100644
--- a/models/vision_language_model.py
+++ b/models/vision_language_model.py
@@ -5,7 +5,7 @@
 
 from typing import Optional
 
-from models.utils import top_k_top_p_filtering
+from models.utils import top_k_top_p_filtering, convert_dtypes
 from models.vision_transformer import ViT
 from models.language_model import LanguageModel
 from models.modality_projector import ModalityProjector
@@ -73,6 +73,9 @@ def forward(self, input_ids, images, attention_mask=None, targets=None):
             logits = self.decoder.head(logits) # Apply LM head
             # Loss is calculated over all tokens, but `targets` (labels) will have -100 for non-answer tokens.
             # No need to slice logits based on image embedding size here, as the target mask handles it.
+            if os.environ.get("ModelDowncast"):
+                logits = convert_dtypes(logits, torch.float32) # if the model is downcast, it is more stable to compute the loss on fp32 logits
+
             loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)
 
         return logits, loss
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..9c01323b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+torch
+numpy
+torchvision
+pillow
+datasets
+huggingface-hub
+transformers
+wandb
+einops
\ No newline at end of file
diff --git a/train.py b/train.py
index 17ce4018..2a18681a 100644
--- a/train.py
+++ b/train.py
@@ -17,6 +17,8 @@
 torch.manual_seed(0)
 if torch.cuda.is_available():
     torch.cuda.manual_seed_all(0)
+if torch.backends.mps.is_available():
+    torch.mps.manual_seed(0)
 
 from data.collators import VQACollator
 from data.datasets import VQADataset
@@ -30,6 +32,8 @@
 #Otherwise, the tokenizer will throw a warning
 import os
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # reduce CUDA allocator fragmentation; not yet benchmarked without it
+
 
 # Fix for "Decompressed data too large" error with certain PNGs
 import PIL.PngImagePlugin
@@ -40,6 +44,9 @@ def seed_worker(worker_id):
     numpy.random.seed(worker_seed)
     random.seed(worker_seed)
 
+def get_num_workers():
+    return min(8, os.cpu_count())
+
 def init_dist():
     dist.init_process_group(backend='nccl')
     torch.cuda.set_device(dist.get_rank())
@@ -67,6 +74,25 @@ def dist_gather(o):
 def wrap_model(model):
     return DistributedDataParallel(model, device_ids=[dist.get_rank()])
 
+def downcast_model(model, dtype: torch.dtype):
+    os.environ["ModelDowncast"] = "1"
+    return model.to(dtype)
+
+def check_bf16_and_compile_availability(device):
+    """Checks whether bfloat16 and torch.compile are available on the given device."""
+    bf16_available = False
+    torch_compile_available = False
+    if device.type == "mps":
+        bf16_available = True
+        torch_compile_available = False # torch.compile is not supported on MPS yet
+
+    elif device.type == "cuda":
+        bf16_available = torch.cuda.is_bf16_supported()
+        major, minor = torch.cuda.get_device_capability()
+        torch_compile_available = hasattr(torch, 'compile') and (major, minor) >= (7, 5) # tuple comparison so e.g. (8, 0) also passes
+
+    return bf16_available, torch_compile_available
+
 def get_run_name(train_cfg, vlm_cfg):
     dataset_size = "full_ds" if train_cfg.data_cutoff_idx is None else f"{train_cfg.data_cutoff_idx}samples"
     batch_size = f"bs{int(train_cfg.batch_size*get_world_size()*train_cfg.gradient_accumulation_steps)}"
@@ -134,13 +160,16 @@ def get_dataloaders(train_cfg, vlm_cfg):
     g.manual_seed(0)
 
     # Create dataloaders
 
+    num_workers = get_num_workers()
+    if is_master():
+        print(f"nanoVLM will use {num_workers} workers for dataloading")
     train_loader = DataLoader(
         train_dataset,
         batch_size=train_cfg.batch_size, # =per device BS in DDP
         collate_fn=vqa_collator,
-        num_workers=8,
-        pin_memory=True,
+        num_workers=num_workers,
+        pin_memory=not torch.backends.mps.is_available(),
         drop_last=True,
         worker_init_fn=seed_worker,
         generator=g,
@@ -158,8 +187,8 @@
         batch_size=train_cfg.batch_size,
         sampler=val_sampler,
         collate_fn=vqa_collator,
-        num_workers=8,
-        pin_memory=True,
+        num_workers=num_workers,
+        pin_memory=not torch.backends.mps.is_available(),
         drop_last=True,
         worker_init_fn=seed_worker,
         generator=g,
@@ -184,6 +213,27 @@ def get_lr(it, max_lr, max_steps):
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
     return min_lr + coeff * (max_lr - min_lr)
 
+def save_snapshot(save_directory, optimizer, current_step, current_epoch, vlm_cfg, train_cfg, dtype):
+    """Saves the optimizer state, current step/epoch, configs and training dtype"""
+    if is_master():
+        snapshot_path = os.path.join(save_directory, "snapshot.pt")
+        os.makedirs(save_directory, exist_ok=True)
+
+        # Save optimizer state and metadata
+        snapshot = {
+            "optimizer_state": optimizer.state_dict(),
+            "optimizer_params": optimizer.param_groups,
+            "step": current_step,
+            "epoch": current_epoch,
+            "vlm_cfg": asdict(vlm_cfg),
+            "train_cfg": asdict(train_cfg),
+            "dtype": "bfloat16" if dtype == torch.bfloat16 else "float16"
+        }
+        # Save snapshot to a .pt file
+        torch.save(snapshot, snapshot_path)
+        print(f"Snapshot saved to {snapshot_path}")
+
+
 def train(train_cfg, vlm_cfg):
     train_loader, val_loader = get_dataloaders(train_cfg, vlm_cfg)
     tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens, vlm_cfg.lm_chat_template)
@@ -209,7 +259,7 @@ def train(train_cfg, vlm_cfg):
         model = VisionLanguageModel.from_pretrained(vlm_cfg.vlm_checkpoint_path)
     else:
        model = VisionLanguageModel(vlm_cfg, load_backbone=vlm_cfg.vlm_load_backbone_weights)
-    
+
     if is_master():
         print(f"nanoVLM initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
         print(f"Training summary{' (global)' if is_dist() else ''}: {len(train_loader.dataset)} samples, {int(len(train_loader)*get_world_size())} batches/epoch, batch size {int(train_cfg.batch_size*get_world_size()*train_cfg.gradient_accumulation_steps)}{', training on ' + str(get_world_size()) + ' GPUs' if is_dist() else ''}")
@@ -229,7 +279,7 @@ def train(train_cfg, vlm_cfg):
     else:
         for p in list(model.decoder.parameters()) + list(model.vision_encoder.parameters()):
             p.requires_grad = False
-    
+
     optimizer = optim.AdamW(param_groups)
     all_params = [p for group in optimizer.param_groups for p in group['params']]
 
@@ -243,10 +293,25 @@ def train(train_cfg, vlm_cfg):
         torch.mps.empty_cache()
 
     print(f"Using device: {device}")
+    bf16_available, torch_compile_available = check_bf16_and_compile_availability(device)
+    Dtype = torch.bfloat16 if device.type in ['cuda', 'cpu', 'mps'] and bf16_available else torch.float16
+    assert isinstance(Dtype, torch.dtype), f"Dtype is {Dtype}, expected torch.dtype"
+
+    if is_master():
+        print(f"Device compatibility:\nbfloat16 available: {bf16_available}, torch.compile available: {torch_compile_available}")
+        print(f"nanoVLM will use {Dtype}!")
+
+    if train_cfg.downcast_model:
+        model = downcast_model(model, Dtype)
+        if is_master():
+            print(f"Downcasting model to {Dtype}")
+
     model.to(device)
-
-    if train_cfg.compile:
-        model = torch.compile(model)
+    if train_cfg.compile and torch_compile_available:
+        model = torch.compile(model, mode=train_cfg.compile_mode, dynamic=train_cfg.compile_dynamic)
+        if is_master():
+            print(f'nanoVLM compiled with {train_cfg.compile_mode} mode, dynamic is {train_cfg.compile_dynamic}')
+
 
     if is_dist():
         model = wrap_model(model)
@@ -295,7 +360,7 @@ def train(train_cfg, vlm_cfg):
             fw_bw_start = time.time()
             autocast_context = torch.autocast(
                 device_type=device.type,
-                dtype=torch.bfloat16 if device.type in ['cuda', 'cpu'] else torch.float16
+                dtype=Dtype
             )
             with autocast_context:
                 with context:
@@ -348,6 +413,9 @@ def train(train_cfg, vlm_cfg):
                 model.eval()
                 if device == "cuda":
                     torch.cuda.empty_cache()
+                if device.type == "mps":
+                    torch.mps.empty_cache()
+
                 with torch.no_grad():
                     total_val_loss = 0
                     for batch in val_loader:
@@ -362,12 +430,15 @@ def train(train_cfg, vlm_cfg):
                         total_val_loss += loss.item()
 
                     avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
                     avg_val_loss = mean(dist_gather(avg_val_loss)) if is_dist() else avg_val_loss
-                    if avg_val_loss < best_val_loss:
+                    if avg_val_loss < best_val_loss: # TODO: add a minimum-improvement threshold before saving a snapshot
+                        best_val_loss = avg_val_loss
                         if is_master():
                             save_model = model.module if is_dist() else model # unwrap the model for saving if DDP
                             save_model.save_pretrained(save_directory=os.path.join(vlm_cfg.vlm_checkpoint_path, run_name))
+                            save_snapshot(save_directory=vlm_cfg.vlm_snapshot_path, optimizer=optimizer, current_step=global_step, current_epoch=epoch, vlm_cfg=vlm_cfg, train_cfg=train_cfg, dtype=Dtype)
+
 
                     lmms_results = {}
                     if train_cfg.use_lmms_eval:
                         from evaluation import cli_evaluate
@@ -500,15 +571,31 @@ def main():
     parser.add_argument('--lr_mp', type=float, help='Learning rate for the mapping network')
     parser.add_argument('--lr_backbones', type=float, help='Learning rate for the backbones')
     parser.add_argument('--vlm_checkpoint_path', type=str, help='Path to the VLM checkpoint for loading or saving')
-    parser.add_argument('--compile', type=bool, help='Use torch.compile to optimize the model')
+    # https://docs.pytorch.org/docs/stable/generated/torch.compile.html
+    parser.add_argument('--compile', type=bool, default=False, help='Use torch.compile to optimize the model')
+    parser.add_argument('--compile_mode', type=str, help='torch.compile mode to use when compiling the model', choices=["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"])
     parser.add_argument('--log_wandb', type=bool, help='Log to wandb')
     parser.add_argument('--resume_from_vlm_checkpoint', type=bool, default=False, help='Resume training from VLM checkpoint specified by vlm_checkpoint_path (or default if not provided)')
     parser.add_argument('--no_log_wandb', action='store_true', help='Do not log to wandb')
+    parser.add_argument('--batch_size', type=int, help='Batch size for training')
+    parser.add_argument('--gradient_accumulation_steps', type=int, help='Gradient accumulation steps for training')
+    parser.add_argument('--wandb_entity', type=str, help='Entity to log to in wandb')
+    parser.add_argument('--downcast_model', type=bool, default=False, help='Downcast the model to the training dtype (bfloat16/float16)')
+    parser.add_argument('--use_lmms_eval', type=bool, default=False, help='Use lmms-eval for evaluation')
+    parser.add_argument('--use_135m', type=bool, default=False, help='Use the 135M (SmolLM2-135M-Instruct) mini config')
 
     args = parser.parse_args()
 
-    vlm_cfg = config.VLMConfig()
-    train_cfg = config.TrainConfig()
+    if args.use_135m:
+        vlm_cfg = config.MiniVLMConfig
+        train_cfg = config.MiniTrainerConfig
+        if is_master():
+            print("Using MiniVLMConfig and MiniTrainerConfig")
+    else:
+        vlm_cfg = config.VLMConfig()
+        train_cfg = config.TrainConfig()
+        if is_master():
+            print("Using VLMConfig and TrainConfig")
 
     if args.lr_mp is not None:
         train_cfg.lr_mp = args.lr_mp
@@ -520,12 +607,29 @@ def main():
         train_cfg.compile = args.compile
     if args.no_log_wandb is True:
         train_cfg.log_wandb = False
+
+    if args.wandb_entity is not None:
+        train_cfg.wandb_entity = args.wandb_entity
+
+    if args.batch_size is not None:
+        train_cfg.batch_size = args.batch_size
+    if args.gradient_accumulation_steps is not None:
+        train_cfg.gradient_accumulation_steps = args.gradient_accumulation_steps
+
+    if args.downcast_model is not None:
+        train_cfg.downcast_model = args.downcast_model
 
     if args.resume_from_vlm_checkpoint and args.vlm_checkpoint_path is not None:
-        train_cfg.resume_from_vlm_checkpoint = True
-        # When resuming a full VLM, we don't need to load individual backbone weights from original sources
+        train_cfg.resume_from_vlm_checkpoint = True
         vlm_cfg.vlm_load_backbone_weights = False
 
+    if train_cfg.use_lmms_eval:
+        # os.system returns the shell's exit status; a non-zero value means the install failed
+        ret = os.system("uv pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git")
+        if ret != 0:
+            print("Error installing lmms-eval, will NOT use it for now...")
+            train_cfg.use_lmms_eval = False
+
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         init_dist()
 
@@ -541,4 +645,4 @@ def main():
         destroy_dist()
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
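For reference, a minimal sketch of how the new `train.py` options introduced in this patch might be invoked (illustrative only; the chosen values are assumptions, and the `type=bool` arguments treat any non-empty string as true):

```bash
# Assumed example: train the 135M mini config with downcasting enabled,
# overriding batch size and gradient accumulation from the command line.
python train.py --use_135m True --downcast_model True --batch_size 1 --gradient_accumulation_steps 8
```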