1919import torch
2020import torch .distributed as dist
2121import torch .multiprocessing as mp
22- from distill_trainer import EagleTPTrainer
2322from eagle_utils import DataCollatorWithPadding , make_eagle_supervised_data_module
23+ from trainer .distill_trainer import EagleSGLTrainer , EagleTPTrainer
2424from transformers import AutoTokenizer
2525
26- # Hyperparameters for profiling
2726torch .manual_seed (0 )
2827
2928
@@ -34,10 +33,10 @@ def _setup_distributed(rank, args, backend="nccl"):
3433 os .environ ["LOCAL_RANK" ] = str (rank )
3534 # Initialize process group
3635 dist .init_process_group (backend , rank = rank , world_size = args .world_size )
37- if rank in args .student_ranks :
38- torch .cuda .set_device (args .student_devices [rank ])
36+ if rank in args .teacher_ranks :
37+ torch .cuda .set_device (args .teacher_devices [rank ])
3938 else :
40- torch .cuda .set_device (args .teacher_devices [rank - len (args .student_ranks )])
39+ torch .cuda .set_device (args .student_devices [rank - len (args .teacher_ranks )])
4140 print (
4241 f"Starting process rank={ rank } , device={ torch .cuda .current_device ()} , world_size={ args .world_size } "
4342 )
@@ -51,7 +50,10 @@ def train(rank, args):
5150 tokenizer = AutoTokenizer .from_pretrained (
5251 args .model_path , model_max_length = args .training_seq_len
5352 )
54- data_module = make_eagle_supervised_data_module (tokenizer , args , use_offline_training = False )
53+ args .use_offline_training = False
54+ args .vlm_processor = None
55+ args .offline_data_path = None
56+ data_module = make_eagle_supervised_data_module (tokenizer , args )
5557
5658 train_dataloader = torch .utils .data .DataLoader (
5759 data_module ["train_dataset" ],
@@ -61,42 +63,56 @@ def train(rank, args):
6163 collate_fn = DataCollatorWithPadding (max_length = args .training_seq_len ),
6264 drop_last = True ,
6365 )
64-
65- trainer = EagleTPTrainer (rank , args , tokenizer , train_dataloader )
66+ trainer_cls = {
67+ "sglang" : EagleSGLTrainer ,
68+ "hf" : EagleTPTrainer ,
69+ }[args .teacher_backend ]
70+ trainer = trainer_cls (rank , args , tokenizer , train_dataloader )
6671 trainer .train ()
67- trainer .save_pretrained (args .out_path )
72+ trainer .save (args .out_path )
6873
6974
7075def main ():
7176 parser = argparse .ArgumentParser (description = "Multi-GPU distributed two-stage forward example" )
77+
78+ # Training args
7279 parser .add_argument ("--model_path" , type = str , default = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" )
73- parser .add_argument ("--student_devices" , type = list , default = [0 , 1 , 2 , 3 ])
74- parser .add_argument ("--teacher_devices" , type = list , default = [4 , 5 , 6 , 7 ])
75- parser .add_argument (
76- "--data_path" , type = str , default = "data/magpie_llama3.2_1b_generated/data.cleaned.jsonl"
77- )
80+ parser .add_argument ("--data_path" , type = str , required = True , help = "Training dataset." )
7881 parser .add_argument ("--training_seq_len" , type = str , default = 1024 )
7982 parser .add_argument ("--eagle_config_path" , type = str , default = "eagle_config.json" )
80- parser .add_argument (
81- "--lazy_preprocess" , type = bool , default = True , help = "Whether to use lazy preprocessing."
82- )
8383 parser .add_argument ("--out_path" , type = str , default = "ckpts/fast-trained" )
8484 parser .add_argument ("--lr" , type = float , default = 1e-5 )
8585 parser .add_argument ("--epoch" , type = int , default = 1 )
86+ parser .add_argument ("--batch_size" , type = int , default = 8 , help = "Total bs across all ranks." )
87+
88+ # Trainer args
89+ parser .add_argument ("--teacher_backend" , type = str , choices = ["sglang" , "hf" ], default = "sglang" )
90+ parser .add_argument (
91+ "--teacher_ep_size" ,
92+ type = int ,
93+ default = 1 ,
94+ help = "Teacher EP size, only used for sglang backend." ,
95+ )
96+ parser .add_argument ("--teacher_devices" , type = list , default = [0 , 1 , 2 , 3 ])
97+ parser .add_argument ("--student_devices" , type = list , default = [4 , 5 , 6 , 7 ])
98+ parser .add_argument (
99+ "--lazy_preprocess" , type = bool , default = True , help = "Whether to use lazy preprocessing."
100+ )
101+ parser .add_argument ("--log_interval" , type = int , default = 50 )
102+ parser .add_argument ("--save_interval" , type = int , default = 20000 )
86103 parser .add_argument (
87- "--batch_size " , type = int , default = 4 , help = "Total batch size across all parallel ranks ."
104+ "--total_steps " , type = int , default = 60000 , help = "Total number of steps for debugging ."
88105 )
89106 parser .add_argument ("--master_port" , type = str , default = "12357" )
90107
91108 args = parser .parse_args ()
92109 # TODO: add sanity check for args
93110
94111 def set_ranks (args ):
95- # TODO(hg): This is for TP-DDP setting only. Add "no-parallel", "MP", "FSDP".
96112 args .world_size = len (args .teacher_devices ) + len (args .student_devices )
97- args .student_ranks = list (range (len (args .student_devices )))
98- args .teacher_ranks = list (
99- range (len (args .student_devices ), len (args .student_devices ) + len (args .teacher_devices ))
113+ args .teacher_ranks = list (range (len (args .teacher_devices )))
114+ args .student_ranks = list (
115+ range (len (args .teacher_devices ), len (args .teacher_devices ) + len (args .student_devices ))
100116 )
101117
102118 set_ranks (args )
0 commit comments