LoRA #33 (Merged)

1 change: 1 addition & 0 deletions .gitignore
@@ -12,3 +12,4 @@ logs/
 *.pt
 pyrightconfig.json
 local
+enora's junk/
1 change: 1 addition & 0 deletions requirements.txt
@@ -11,3 +11,4 @@ pyglottolog==3.16.0
 torch==2.6.0
 tqdm==4.67.1
 transformers==4.51.1
+peft==0.17.1
19 changes: 17 additions & 2 deletions run.py
@@ -9,6 +9,7 @@
 import torch
 from transformers.models.auto.modeling_auto import AutoModelForPreTraining
 from transformers.models.auto.tokenization_auto import AutoTokenizer
+from peft import LoraConfig, PeftModel, get_peft_model, TaskType
 
 import wandb
 from src.config.config_to_dataclass import config_to_dataclass
@@ -53,6 +54,7 @@ def run(
         models_folder = pathlib.Path(config.models_dir) / experiment_folder.stem
     else:
         models_folder = experiment_folder
+
 
     if config.glottocode is not None:
         # Create subfolders for each language if needed
@@ -67,10 +69,23 @@
         distributed_parameters["device"]
     )
     model.gradient_checkpointing_enable()
+    model.enable_input_require_grads()
+    if config.adapter_dir:
+        model = PeftModel.from_pretrained(model, config.adapter_dir, is_trainable=True)
+    elif config.mode == "lora":
+        lora_config = LoraConfig(
+            task_type=TaskType.SEQ_2_SEQ_LM,
+            r=config.lora_rank,
+            lora_alpha=config.lora_alpha,
+            lora_dropout=config.lora_dropout,
+        )
+        model = get_peft_model(model, lora_config)
+
     if distributed_parameters["distributed"]:
         model = torch.nn.parallel.DistributedDataParallel(
             model, device_ids=[distributed_parameters["local_rank"]]
         )
+
     if config.model_type == "seq2seq":
         dataloaders, dataset = prepare_s2s_dataset.create_dataloaders(
             tokenizer=tokenizer,
@@ -79,8 +94,8 @@
         )
     else:
         raise NotImplementedError()
 
-    if config.mode in ["pretrain", "finetune"]:
+    if config.mode in ["pretrain", "finetune", "lora"]:
         train(
             model,
             tokenizer=tokenizer,
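
Two details in the new block are easy to miss. First, model.enable_input_require_grads() is needed because gradient checkpointing is combined with frozen base weights: without it, the embedding outputs carry requires_grad=False and no gradient ever reaches the adapter layers. Second, since no target_modules is passed to LoraConfig, PEFT falls back to its built-in per-architecture defaults (typically the attention projections). A minimal standalone sketch of the same wiring follows; the checkpoint name is a placeholder, while the hyperparameter values mirror the PR's config defaults:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSeq2SeqLM

# Placeholder checkpoint; run.py loads its model through AutoModelForPreTraining.
model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-small")

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,               # config.lora_rank
    lora_alpha=8,      # config.lora_alpha
    lora_dropout=0.2,  # config.lora_dropout
)
model = get_peft_model(model, lora_config)

# Base weights stay frozen; only the low-rank A/B matrices are trainable.
model.print_trainable_parameters()
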
14 changes: 13 additions & 1 deletion src/config/experiment_config.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass, field
 from typing import Literal
 
-TRAIN_MODE = Literal["pretrain", "predict", "finetune"]
+TRAIN_MODE = Literal["pretrain", "predict", "finetune", "lora"]
 SEGMENTATION_MODE = Literal["segmented", "unsegmented", "both"]
 MODEL_TYPE = Literal["seq2seq", "decoder"]
 CREATE_EXAMPLE_TYPE = Literal["none", "train-only", "train-test"]
@@ -84,6 +84,18 @@ class ExperimentConfig:
     resume_from_checkpoint_id: str | None = None
     """WandB ID (and checkpoint ID) for checkpoint to resume training from."""
 
+    lora_rank: int = 8
+    """Rank of the LoRA update matrices."""
+
+    lora_dropout: float = 0.2
+    """Dropout probability applied inside the LoRA layers."""
+
+    lora_alpha: int = 8
+    """LoRA scaling parameter (the effective scale is lora_alpha / lora_rank)."""
+
+    adapter_dir: str | None = None
+    """Directory containing a trained LoRA adapter. If set, the adapter is loaded onto the pretrained model (with is_trainable=True, so it supports inference or continued training)."""
+
     # ============================
     # Generation
     # ============================
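
For reference, LoRA applies W' = W + (lora_alpha / lora_rank) * B A, so the defaults above (alpha = rank = 8) apply the adapters at scale 1.0. A hypothetical direct construction of a LoRA run, assuming any required ExperimentConfig fields without defaults are supplied elsewhere:

from src.config.experiment_config import ExperimentConfig

# Hypothetical usage with illustrative values; real runs are built via
# src.config.config_to_dataclass from a config file.
config = ExperimentConfig(
    mode="lora",      # TRAIN_MODE value added by this PR
    lora_rank=16,
    lora_alpha=32,    # effective scale = lora_alpha / lora_rank = 2.0
    lora_dropout=0.1,
)
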
13 changes: 8 additions & 5 deletions src/train.py
@@ -5,7 +5,7 @@
 import torch
 import tqdm
 from torch.utils.data import DataLoader, DistributedSampler
-
+from peft import set_peft_model_state_dict
 import wandb
 from src.config.experiment_config import ExperimentConfig
 from src.distributed import DistributedParameters
@@ -60,11 +60,14 @@ def train(
         logger.info(f"Loading from checkpoint {config.resume_from_checkpoint_id}.")
         checkpoint = torch.load(
             models_folder / f"{config.resume_from_checkpoint_id}.checkpoint.pt",
             weights_only=True,
+            map_location=device,
         )
-        (
-            model.module if distributed_parameters["distributed"] else model
-        ).load_state_dict(checkpoint["model_state_dict"])
+        model_to_load = model.module if distributed_parameters["distributed"] else model
+        if config.mode == "lora":
+            set_peft_model_state_dict(model_to_load, checkpoint["model_state_dict"])
+        else:
+            model_to_load.load_state_dict(checkpoint["model_state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         start_epoch = checkpoint["epoch"]
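
The diff shows only the load path. For these checkpoints to round-trip in LoRA mode, the save side (not part of this diff) would have to store the adapter-only state dict, for which PEFT provides the mirror-image helper get_peft_model_state_dict. A sketch under that assumption, where model_to_load, config, optimizer, epoch, and checkpoint_path stand in for names from the surrounding training loop:

import torch
from peft import get_peft_model_state_dict

# Hypothetical save path mirroring the load logic above.
state_dict = (
    get_peft_model_state_dict(model_to_load)  # adapter weights only
    if config.mode == "lora"
    else model_to_load.state_dict()           # full model weights
)
torch.save(
    {
        "model_state_dict": state_dict,
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
    },
    checkpoint_path,
)

Saving just the adapter keeps LoRA checkpoints small (megabytes rather than the size of the full model), which is one of the main practical payoffs of this change.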