31 changes: 27 additions & 4 deletions predict.py
@@ -3,13 +3,31 @@

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BlipForConditionalGeneration, Gemma3ForConditionalGeneration

from config import Configuration
from utils import test_collate_function, visualize_bounding_boxes
import argparse

os.makedirs("outputs", exist_ok=True)

model_class_map = [
    (lambda name: "gemma" in name, Gemma3ForConditionalGeneration),
    (lambda name: "blip" in name, BlipForConditionalGeneration),
    (lambda name: "kimi" in name, AutoModelForCausalLM),
]

def parse_args():
    parser = argparse.ArgumentParser(description="Fine Tune Gemma3 for Object Detection")
    parser.add_argument("--model", type=str, help="Model checkpoint identifier")
    return parser.parse_args()

def get_model_class(model_name):
    model_name = model_name.lower()
    for condition, model_class in model_class_map:
        if condition(model_name):
            return model_class
    return AutoModelForSeq2SeqLM

def get_dataloader(processor):
    test_dataset = load_dataset(cfg.dataset_id, split="test")
@@ -21,15 +39,20 @@ def get_dataloader(processor):
    )
    return test_dataloader


if __name__ == "__main__":
    args = parse_args()
    cfg = Configuration()
    if args.model:
        cfg.model_id = args.model

    processor = AutoProcessor.from_pretrained(cfg.checkpoint_id)
    model = Gemma3ForConditionalGeneration.from_pretrained(
    model_class = get_model_class(cfg.model_id)
    model = model_class.from_pretrained(
        cfg.checkpoint_id,
        torch_dtype=cfg.dtype,
        device_map="cpu",
    )
    )

    model.eval()
    model.to(cfg.device)

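For reviewers skimming the diff: `get_model_class` does a case-insensitive substring match over `model_class_map` and returns the first hit, falling back to `AutoModelForSeq2SeqLM` for anything unmatched. A minimal standalone sketch of that dispatch, using illustrative checkpoint ids that are not pinned by this PR:

```python
# Standalone sketch of the dispatch added to predict.py and train.py.
# Assumes a transformers version that exposes Gemma3ForConditionalGeneration;
# the checkpoint ids below are illustrative, not values from config.py.
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    BlipForConditionalGeneration,
    Gemma3ForConditionalGeneration,
)

model_class_map = [
    (lambda name: "gemma" in name, Gemma3ForConditionalGeneration),
    (lambda name: "blip" in name, BlipForConditionalGeneration),
    (lambda name: "kimi" in name, AutoModelForCausalLM),
]

def get_model_class(model_name):
    # Lower-case once, then return the class paired with the first matching predicate.
    model_name = model_name.lower()
    for condition, model_class in model_class_map:
        if condition(model_name):
            return model_class
    # No match: defer to the seq2seq auto class.
    return AutoModelForSeq2SeqLM

print(get_model_class("google/gemma-3-4b-it").__name__)                   # Gemma3ForConditionalGeneration
print(get_model_class("Salesforce/blip-image-captioning-base").__name__)  # BlipForConditionalGeneration
print(get_model_class("some-org/other-checkpoint").__name__)              # AutoModelForSeq2SeqLM (fallback)
```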
47 changes: 39 additions & 8 deletions train.py
@@ -5,12 +5,13 @@
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BlipForConditionalGeneration, Gemma3ForConditionalGeneration

from config import Configuration
from utils import train_collate_function

import albumentations as A
import argparse

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -24,6 +25,23 @@
    A.ColorJitter(p=0.2),
], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'], filter_invalid_bboxes=True))

model_class_map = [
    (lambda name: "gemma" in name, Gemma3ForConditionalGeneration),
    (lambda name: "blip" in name, BlipForConditionalGeneration),
    (lambda name: "kimi" in name, AutoModelForCausalLM),
]

def parse_args():
    parser = argparse.ArgumentParser(description="Fine Tune Gemma3 for Object Detection")
    parser.add_argument("--model", type=str, help="Model checkpoint identifier")
    return parser.parse_args()

def get_model_class(model_name):
    model_name = model_name.lower()
    for condition, model_class in model_class_map:
        if condition(model_name):
            return model_class
    return AutoModelForSeq2SeqLM

def get_dataloader(processor):
    logger.info("Fetching the dataset")
@@ -61,17 +79,30 @@ def train_model(model, optimizer, cfg, train_dataloader):


if __name__ == "__main__":
    args = parse_args()
    cfg = Configuration()
    if args.model:
        cfg.model_id = args.model
    processor = AutoProcessor.from_pretrained(cfg.model_id)
    model_class = get_model_class(cfg.model_id)
    train_dataloader = get_dataloader(processor)

    logger.info("Getting model & turning only attention parameters to trainable")
    model = Gemma3ForConditionalGeneration.from_pretrained(
        cfg.model_id,
        torch_dtype=cfg.dtype,
        device_map="cpu",
        attn_implementation="eager",
    )

    if "gemma" in cfg.model_id.lower():
        model = model_class.from_pretrained(
            cfg.model_id,
            torch_dtype=cfg.dtype,
            device_map="cpu",
            attn_implementation="eager",
        )
    else:
        model = model_class.from_pretrained(
            cfg.model_id,
            torch_dtype=cfg.dtype,
            device_map="cpu",
        )

    for name, param in model.named_parameters():
        if "attn" in name:
            param.requires_grad = True
@@ -86,7 +117,7 @@ def train_model(model, optimizer, cfg, train_dataloader):
    optimizer = torch.optim.AdamW(params_to_train, lr=cfg.learning_rate)

    wandb.init(
        project=cfg.project_name,
        project=cfg.project_name if hasattr(cfg, "project_name") else None,
        name=cfg.run_name if hasattr(cfg, "run_name") else None,
        config=vars(cfg),
    )
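One design note on the `if "gemma"` branch in train.py: the two `from_pretrained` calls differ only in `attn_implementation="eager"`. A possible way to express the same behaviour without duplicating the call, sketched here as an editorial suggestion (the `load_model` helper is hypothetical, not part of this PR):

```python
def load_model(model_class, model_id, dtype):
    # Hypothetical helper equivalent to the if/else in train.py: build the kwargs
    # once and add the Gemma-specific flag only when the checkpoint id calls for it.
    load_kwargs = {"torch_dtype": dtype, "device_map": "cpu"}
    if "gemma" in model_id.lower():
        # Gemma 3 is loaded with eager attention here; other models keep their default.
        load_kwargs["attn_implementation"] = "eager"
    return model_class.from_pretrained(model_id, **load_kwargs)

# Usage mirroring the script (cfg and model_class as defined in train.py):
# model = load_model(model_class, cfg.model_id, cfg.dtype)
```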