111 changes: 111 additions & 0 deletions scripts/train.py
@@ -0,0 +1,111 @@
import torch
from torch.utils.data import DataLoader
from transformers import LlamaConfig

from speculators.train.data import Eagle3SampleFileDataset, create_collate_fn
from speculators.train.distributed_batch_sampler import (
    MultipackDistributedBatchSamplerV2,
)
from speculators.train.eagle3.core import Eagle3DraftModel, Eagle3VerifierLMHead
from speculators.train.logger import setup_metric_logger, setup_root_logger
from speculators.train.trainer import Trainer
from speculators.train.utils import maybe_destroy_distributed, maybe_setup_distributed


local_rank, world_size, rank, is_distributed = maybe_setup_distributed()


DEVICE = torch.device(local_rank)
EPOCHS = 10
draft_vocab_size = 5000
total_seq_len = 5120
datapath = "./data"
verifier_model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"


# TEMP MODEL SETUP
llama_config = LlamaConfig.from_pretrained(verifier_model_name_or_path)
llama_config._attn_implementation = "simple_flex_attention"
hidden_size = llama_config.hidden_size
verifier_vocab_size = llama_config.vocab_size

# d2t_vocab = torch.zeros(draft_vocab_size, dtype=torch.long).to(DEVICE)
# t2d_vocab = (
# torch.cat(
# [
# torch.ones(draft_vocab_size),
# torch.zeros(llama_config.vocab_size - draft_vocab_size),
# ]
# )
# .to(torch.bool)
# .to(DEVICE)
# )
# d2t/t2d vocab mappings are expected to be tensors written with torch.save
# (despite the .npy extension); see the sketch after this file for one way
# to generate them offline.
d2t_vocab = torch.load("d2t.npy").to(DEVICE)
t2d_vocab = torch.load("t2d.npy").to(DEVICE)

setup_metric_logger(loggers="trackio", run_name=None, output_dir="./logs")
setup_root_logger()
# END TEMP MODEL SETUP

draft_model = Eagle3DraftModel(
verifier_model_name_or_path=verifier_model_name_or_path,
hidden_size=hidden_size,
t2d_vocab=t2d_vocab,
d2t_vocab=d2t_vocab,
decoder_layer_config=llama_config,
verifier_vocab_size=verifier_vocab_size,
verifier_pad_token_id=None,
num_layers=1,
ttt_steps=3,
)

verifier_lm_head = Eagle3VerifierLMHead(
hidden_size=hidden_size, draft_vocab_size=draft_vocab_size
)
verifier_lm_head.load_verifier_lm_head(verifier_model_name_or_path, t2d_vocab)

dataset = Eagle3SampleFileDataset(datapath=datapath, max_len=total_seq_len)
batch_sampler = MultipackDistributedBatchSamplerV2(
batch_max_length=total_seq_len,
lengths=dataset.approx_lengths(),
num_replicas=world_size,
    rank=rank,  # use the global rank so sharding stays correct across nodes
)
train_loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
num_workers=16,
pin_memory=True,
collate_fn=create_collate_fn(total_seq_len),
)


# todo: make config better
config = {
"num_epochs": EPOCHS,
"save_path": "./checkpoints",
"lr": 1e-5,
"total_seq_len": total_seq_len,
"datapath": "./data",
"resume_from_checkpoint": True,
}


trainer = Trainer(
draft_model,
verifier_lm_head,
config,
train_loader,
None,
is_distributed,
local_rank,
world_size,
)
trainer.run_training()

maybe_destroy_distributed()


# RUN WITH:
# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 scripts/train.py
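
The d2t.npy / t2d.npy files loaded above are produced outside this script. Below is a minimal sketch of one way to generate them, assuming the same identity-style mapping as the commented-out block in train.py; the helper file name and the use of torch.save are assumptions, chosen only to match the torch.load calls above.

# generate_vocab_maps.py -- hypothetical helper, not part of this PR
import torch
from transformers import LlamaConfig

draft_vocab_size = 5000
config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# d2t: per-draft-token offset into the target (verifier) vocab.
# All zeros maps draft id i to target id i, matching the commented-out
# placeholder in train.py.
d2t = torch.zeros(draft_vocab_size, dtype=torch.long)

# t2d: boolean mask over the target vocab marking which target tokens are
# representable in the draft vocab (here: the first draft_vocab_size ids).
t2d = torch.cat(
    [
        torch.ones(draft_vocab_size),
        torch.zeros(config.vocab_size - draft_vocab_size),
    ]
).to(torch.bool)

# train.py reads these with torch.load, so write them with torch.save
# (the .npy extension is kept only to match the training script).
torch.save(d2t, "d2t.npy")
torch.save(t2d, "t2d.npy")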
149 changes: 149 additions & 0 deletions src/speculators/train/checkpointer.py
@@ -0,0 +1,149 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)


class BaseCheckpointer(ABC):
"""Helper class to save and load checkpoints.

Checkpoint file structure:
../path/
0/ # epoch number
model_state_dict.pt
optimizer_state_dict.pt
1/
model_state_dict.pt
optimizer_state_dict.pt
...
"""

def __init__(self, path: Path | str, try_load_last_checkpoint: bool = True):
self.path = Path(path)
if try_load_last_checkpoint:
self.previous_epoch: int = self._get_previous_epoch()
else:
self.previous_epoch: int = -1

@abstractmethod
def load_model_state_dict(self, model: torch.nn.Module):
raise NotImplementedError

@abstractmethod
def load_optimizer_state_dict(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer
):
raise NotImplementedError

@abstractmethod
def save_checkpoint(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
):
raise NotImplementedError

def _get_previous_epoch(self) -> int:
if not self.path.exists():
return -1
last_checkpoint_num = -1
for d in self.path.iterdir():
if d.is_dir():
try:
last_checkpoint_num = max(last_checkpoint_num, int(d.name))
except ValueError:
continue
return last_checkpoint_num

def model_path(self, epoch: int):
model_fname = "model_state_dict.pt"
return self.path / str(epoch) / model_fname

def optimizer_path(self, epoch: int):
optimizer_fname = "optimizer_state_dict.pt"
return self.path / str(epoch) / optimizer_fname


class SingleGPUCheckpointer(BaseCheckpointer):
def load_model_state_dict(self, model: torch.nn.Module):
full_state_dict = torch.load(
self.model_path(self.previous_epoch),
weights_only=True,
map_location="cuda:0", # todo: make this configurable
)
model.load_state_dict(full_state_dict)

def load_optimizer_state_dict(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer
):
full_state_dict = torch.load(
self.optimizer_path(self.previous_epoch),
weights_only=True,
map_location="cuda:0", # todo: make this configurable
)
optimizer.load_state_dict(full_state_dict)

def save_checkpoint(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
):
os.makedirs(self.path / str(epoch), exist_ok=True)
torch.save(model.state_dict(), self.model_path(epoch))
torch.save(optimizer.state_dict(), self.optimizer_path(epoch))


class DistributedCheckpointer(BaseCheckpointer):
def load_model_state_dict(self, model: torch.nn.Module):
full_state_dict = torch.load(
self.model_path(self.previous_epoch),
mmap=True,
weights_only=True,
map_location="cpu",
)
set_model_state_dict(
model,
full_state_dict,
options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
dist.barrier()

def load_optimizer_state_dict(self, model, optimizer: torch.optim.Optimizer):
full_state_dict = torch.load(
self.optimizer_path(self.previous_epoch),
mmap=True,
weights_only=True,
map_location="cpu",
)
set_optimizer_state_dict(
model,
optimizer,
full_state_dict,
options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
dist.barrier()

def save_checkpoint(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
):
model_state_dict = get_model_state_dict(
model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
)
optimizer_state_dict = get_optimizer_state_dict(
model,
optimizer,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)

if dist.get_rank() == 0:
# Only rank 0 saves the checkpoint
os.makedirs(self.path / str(epoch), exist_ok=True)
torch.save(model_state_dict, self.model_path(epoch))
torch.save(optimizer_state_dict, self.optimizer_path(epoch))

dist.barrier()
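
For reference, a minimal sketch of how these checkpointers might be driven from a training loop; only the checkpointer API defined in this file is taken from the PR, while the model, optimizer, device, and epoch count are illustrative stand-ins.

# Illustrative resume flow (everything outside the checkpointer API is an assumption).
import torch

from speculators.train.checkpointer import SingleGPUCheckpointer

# Stand-ins for the draft model and its optimizer; the single-GPU loader
# currently maps checkpoints to cuda:0, so the model lives there too.
model = torch.nn.Linear(8, 8).to("cuda:0")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

checkpointer = SingleGPUCheckpointer("./checkpoints")

# previous_epoch is -1 when no checkpoint directory exists yet.
start_epoch = checkpointer.previous_epoch + 1
if checkpointer.previous_epoch >= 0:
    checkpointer.load_model_state_dict(model)
    checkpointer.load_optimizer_state_dict(model, optimizer)

for epoch in range(start_epoch, 10):
    ...  # one training epoch would go here
    checkpointer.save_checkpoint(model, optimizer, epoch)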