+ import os
import torch
from transformers import LlamaConfig

from speculators.train.eagle3.core import Eagle3DraftModel
from speculators.train.data import Eagle3SampleFileDataset, create_collate_fn
from torch.utils.data import DataLoader

+ from torch.nn.parallel import DistributedDataParallel as DDP
+ import torch.distributed as dist

- DEVICE = "cuda:0"
+ def maybe_setup_distributed():
+     # Based off of https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html#initialize-ddp-with-torch-distributed-run-torchrun
+     if "LOCAL_RANK" not in os.environ:
+         # No distributed training
+         return 0, 1, 0, False
+     local_rank = int(os.environ.get("LOCAL_RANK", 0))
+     world_size = int(os.environ.get("WORLD_SIZE", 1))
+     torch.accelerator.set_device_index(local_rank)
+     acc = torch.accelerator.current_accelerator()
+     backend = torch.distributed.get_default_backend_for_device(acc)
+     dist.init_process_group(backend)
+     rank = dist.get_rank()
+
+     print(f'Started DDP with local_rank={local_rank}, world_size={world_size}, rank={rank}')
+     return local_rank, world_size, rank, True
+
+ local_rank, world_size, rank, is_distributed = maybe_setup_distributed()
+
+
+ DEVICE = torch.device(local_rank)

EPOCHS = 10
draft_vocab_size = 5000
verifier_vocab_size = 151936
...

# draft_model.load_verifier_lm_head(verifier_model_name_or_path)  # Doesn't work for Qwen2.5 VL, need better head loading method

+ if is_distributed:
+     draft_model = DDP(draft_model, device_ids=[local_rank])
+ opt = torch.optim.Adam(draft_model.parameters(), lr=1e-4)

dataset = Eagle3SampleFileDataset(datapath=datapath, max_len=total_seq_len)
train_loader = DataLoader(
    ...
    pin_memory=True,
    collate_fn=create_collate_fn(total_seq_len),
)
- opt = torch.optim.Adam(draft_model.parameters(), lr=1e-4)


def train_epoch(
@@ -67,18 +91,35 @@ def train_epoch(
    opt: torch.optim.Optimizer,
    epoch: int,
    local_rank: int,
+     is_distributed: bool,
):
    model.train()

    for batch in train_loader:
        batch = {k: v.to(local_rank) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

        _, loss = model(**batch, use_off_policy_tokens=True)
-         print(loss.item())
        opt.zero_grad()
        loss.backward()
        opt.step()

+         loss = loss.detach().clone()
+         if is_distributed:
+             # Note: this is not needed for training, just for logging
+             dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
+
+         if local_rank == 0:
+             print(loss.item())
+
+

for epoch in range(EPOCHS):
-     train_epoch(draft_model, train_loader, None, opt, epoch, DEVICE)
+     train_epoch(draft_model, train_loader, None, opt, epoch, local_rank, is_distributed)
+
+ if is_distributed:
+     dist.destroy_process_group()
+     print(f'Destroyed DDP with local_rank={local_rank}, world_size={world_size}, rank={rank}')
+
+
+ # RUN WITH:
+ # CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 src/speculators/train/training_loop.py
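
The DataLoader arguments between `train_loader = DataLoader(` and `pin_memory=True` are collapsed in this diff view, so it is not visible whether the data is already sharded per rank. If `Eagle3SampleFileDataset` is a map-style dataset and no sampler is passed in the elided arguments, the usual DDP pattern is a `DistributedSampler`; a minimal sketch under those assumptions (not part of this change):

from torch.utils.data.distributed import DistributedSampler

# Shard the dataset across ranks when running under DDP; otherwise fall back
# to ordinary shuffling. Other DataLoader arguments from the script (e.g. a
# batch size) would carry over unchanged.
sampler = DistributedSampler(dataset) if is_distributed else None
train_loader = DataLoader(
    dataset,
    sampler=sampler,
    shuffle=(sampler is None),  # DistributedSampler does its own shuffling
    pin_memory=True,
    collate_fn=create_collate_fn(total_seq_len),
)

# At the top of each epoch, reseed the sampler so shuffling changes per epoch
# but stays consistent across ranks:
#     if sampler is not None:
#         sampler.set_epoch(epoch)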