
Commit 76aa8a1

Add Distributed Data Parallel (DDP)
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 08c8848 commit 76aa8a1

File tree

1 file changed: +45 -4 lines changed


src/speculators/train/training_loop.py

Lines changed: 45 additions & 4 deletions
@@ -1,12 +1,34 @@
+import os
 import torch
 from transformers import LlamaConfig
 
 from speculators.train.eagle3.core import Eagle3DraftModel
 from speculators.train.data import Eagle3SampleFileDataset, create_collate_fn
 from torch.utils.data import DataLoader
 
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
 
-DEVICE = "cuda:0"
+def maybe_setup_distributed():
+    # Based off of https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html#initialize-ddp-with-torch-distributed-run-torchrun
+    if "LOCAL_RANK" not in os.environ:
+        # No distributed training
+        return 0, 1, 0, False
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    torch.accelerator.set_device_index(local_rank)
+    acc = torch.accelerator.current_accelerator()
+    backend = torch.distributed.get_default_backend_for_device(acc)
+    dist.init_process_group(backend)
+    rank = dist.get_rank()
+
+    print(f'Started DDP with local_rank={local_rank}, world_size={world_size}, rank={rank}')
+    return local_rank, world_size, rank, True
+
+local_rank, world_size, rank, is_distributed = maybe_setup_distributed()
+
+
+DEVICE = torch.device(local_rank)
 EPOCHS = 10
 draft_vocab_size = 5000
 verifier_vocab_size = 151936
@@ -47,6 +69,9 @@
 
 # draft_model.load_verifier_lm_head(verifier_model_name_or_path) # Doesn't work for Qwen2.5 VL, need better head loading method
 
+if is_distributed:
+    draft_model = DDP(draft_model, device_ids=[local_rank])
+opt = torch.optim.Adam(draft_model.parameters(), lr=1e-4)
 
 dataset = Eagle3SampleFileDataset(datapath=datapath, max_len=total_seq_len)
 train_loader = DataLoader(
@@ -57,7 +82,6 @@
     pin_memory=True,
     collate_fn=create_collate_fn(total_seq_len),
 )
-opt = torch.optim.Adam(draft_model.parameters(), lr=1e-4)
 
 
 def train_epoch(
@@ -67,18 +91,35 @@ def train_epoch(
     opt: torch.optim.Optimizer,
     epoch: int,
     local_rank: int,
+    is_distributed: bool,
 ):
     model.train()
 
     for batch in train_loader:
         batch = {k: v.to(local_rank) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
 
         _, loss = model(**batch, use_off_policy_tokens=True)
-        print(loss.item())
         opt.zero_grad()
         loss.backward()
         opt.step()
 
+        loss = loss.detach().clone()
+        if is_distributed:
+            # Note: this is not needed for training, just for logging
+            dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
+
+        if local_rank == 0:
+            print(loss.item())
+
+
 
 for epoch in range(EPOCHS):
-    train_epoch(draft_model, train_loader, None, opt, epoch, DEVICE)
+    train_epoch(draft_model, train_loader, None, opt, epoch, local_rank, is_distributed)
+
+if is_distributed:
+    dist.destroy_process_group()
+    print(f'Destroyed DDP with local_rank={local_rank}, world_size={world_size}, rank={rank}')
+
+
+# RUN WITH:
+# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 src/speculators/train/training_loop.py
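
Editor's note (not part of this commit): the diff does not appear to add a per-rank sampler, so under DDP each process would iterate over the full dataset. A common companion to DDP is torch.utils.data.distributed.DistributedSampler, which shards the data across ranks. Below is a minimal sketch reusing names from the file above (dataset, create_collate_fn, total_seq_len, train_epoch, etc.); the batch size and other DataLoader arguments are placeholders, not the values used in this file.

# Hypothetical sketch: shard data per rank with DistributedSampler.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset, shuffle=True) if is_distributed else None
train_loader = DataLoader(
    dataset,
    batch_size=1,              # placeholder value
    shuffle=(sampler is None),  # sampler and shuffle are mutually exclusive
    sampler=sampler,
    pin_memory=True,
    collate_fn=create_collate_fn(total_seq_len),
)

for epoch in range(EPOCHS):
    if sampler is not None:
        sampler.set_epoch(epoch)  # reshuffle the per-rank shards each epoch
    train_epoch(draft_model, train_loader, None, opt, epoch, local_rank, is_distributed)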

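Editor's note (not part of this commit): dist.ReduceOp.AVG used for the loss logging above is not supported by every process-group backend (to my knowledge it is NCCL-specific). A backend-agnostic sketch of the same rank-0 logging, assuming the loss, world_size, and dist names from the file above, would sum and divide manually:

# Hypothetical alternative to dist.reduce(..., op=dist.ReduceOp.AVG)
loss = loss.detach().clone()
if is_distributed:
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)  # every rank receives the summed loss
    loss /= world_size                           # convert the sum into a mean
if local_rank == 0:
    print(loss.item())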