Skip to content

Commit f5bb1ae

Browse files
committed
squash: new trainer with HF and SGL backend
Signed-off-by: h-guo18 <[email protected]>
1 parent 3aedb33 commit f5bb1ae

File tree

6 files changed: +388 -66 lines changed

examples/speculative_decoding/ar_validate.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,11 +26,14 @@
2626
mto.enable_huggingface_checkpointing()
2727

2828

29-
def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None):
29+
def validate_ar(
30+
model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None, disable_pbar=False
31+
):
3032
validator = HFARValidation(model, tokenizer)
3133
num_samples = min(num_samples, len(ds))
3234
ars = []
33-
for i in tqdm(range(num_samples), desc="Validating AR"):
35+
print("validating AR...")
36+
for i in tqdm(range(num_samples), disable=disable_pbar):
3437
prompt = ds[i]["prompt"][0]
3538
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
3639
# Apply chat template to the prompt, continuing with assistant response

examples/speculative_decoding/eagle_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -518,7 +518,7 @@ def compute_loss(self, *args, **kwargs):
518518
kwargs.pop("num_items_in_batch", None)
519519
loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs)
520520
if hasattr(outputs, "train_acc"):
521-
self.state.training_accs.append(outputs.train_acc)
521+
self.state.training_accs.append([acc.item() for acc in outputs.train_acc])
522522
return loss
523523

524524

examples/speculative_decoding/train.py

Lines changed: 38 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,10 @@
1919
import torch
2020
import torch.distributed as dist
2121
import torch.multiprocessing as mp
22-
from distill_trainer import EagleTPTrainer
2322
from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module
23+
from trainer.distill_trainer import EagleSGLTrainer, EagleTPTrainer
2424
from transformers import AutoTokenizer
2525

26-
# Hyperparameters for profiling
2726
torch.manual_seed(0)
2827

2928

@@ -34,10 +33,10 @@ def _setup_distributed(rank, args, backend="nccl"):
3433
os.environ["LOCAL_RANK"] = str(rank)
3534
# Initialize process group
3635
dist.init_process_group(backend, rank=rank, world_size=args.world_size)
37-
if rank in args.student_ranks:
38-
torch.cuda.set_device(args.student_devices[rank])
36+
if rank in args.teacher_ranks:
37+
torch.cuda.set_device(args.teacher_devices[rank])
3938
else:
40-
torch.cuda.set_device(args.teacher_devices[rank - len(args.student_ranks)])
39+
torch.cuda.set_device(args.student_devices[rank - len(args.teacher_ranks)])
4140
print(
4241
f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}"
4342
)
@@ -51,7 +50,10 @@ def train(rank, args):
5150
tokenizer = AutoTokenizer.from_pretrained(
5251
args.model_path, model_max_length=args.training_seq_len
5352
)
54-
data_module = make_eagle_supervised_data_module(tokenizer, args, use_offline_training=False)
53+
args.use_offline_training = False
54+
args.vlm_processor = None
55+
args.offline_data_path = None
56+
data_module = make_eagle_supervised_data_module(tokenizer, args)
5557

5658
train_dataloader = torch.utils.data.DataLoader(
5759
data_module["train_dataset"],
@@ -61,42 +63,56 @@ def train(rank, args):
6163
collate_fn=DataCollatorWithPadding(max_length=args.training_seq_len),
6264
drop_last=True,
6365
)
64-
65-
trainer = EagleTPTrainer(rank, args, tokenizer, train_dataloader)
66+
trainer_cls = {
67+
"sglang": EagleSGLTrainer,
68+
"hf": EagleTPTrainer,
69+
}[args.teacher_backend]
70+
trainer = trainer_cls(rank, args, tokenizer, train_dataloader)
6671
trainer.train()
67-
trainer.save_pretrained(args.out_path)
72+
trainer.save(args.out_path)
6873

6974

7075
def main():
7176
parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example")
77+
78+
# Training args
7279
parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
73-
parser.add_argument("--student_devices", type=list, default=[0, 1, 2, 3])
74-
parser.add_argument("--teacher_devices", type=list, default=[4, 5, 6, 7])
75-
parser.add_argument(
76-
"--data_path", type=str, default="data/magpie_llama3.2_1b_generated/data.cleaned.jsonl"
77-
)
80+
parser.add_argument("--data_path", type=str, required=True, help="Training dataset.")
7881
parser.add_argument("--training_seq_len", type=str, default=1024)
7982
parser.add_argument("--eagle_config_path", type=str, default="eagle_config.json")
80-
parser.add_argument(
81-
"--lazy_preprocess", type=bool, default=True, help="Whether to use lazy preprocessing."
82-
)
8383
parser.add_argument("--out_path", type=str, default="ckpts/fast-trained")
8484
parser.add_argument("--lr", type=float, default=1e-5)
8585
parser.add_argument("--epoch", type=int, default=1)
86+
parser.add_argument("--batch_size", type=int, default=8, help="Total bs across all ranks.")
87+
88+
# Trainer args
89+
parser.add_argument("--teacher_backend", type=str, choices=["sglang", "hf"], default="sglang")
90+
parser.add_argument(
91+
"--teacher_ep_size",
92+
type=int,
93+
default=1,
94+
help="Teacher EP size, only used for sglang backend.",
95+
)
96+
parser.add_argument("--teacher_devices", type=list, default=[0, 1, 2, 3])
97+
parser.add_argument("--student_devices", type=list, default=[4, 5, 6, 7])
98+
parser.add_argument(
99+
"--lazy_preprocess", type=bool, default=True, help="Whether to use lazy preprocessing."
100+
)
101+
parser.add_argument("--log_interval", type=int, default=50)
102+
parser.add_argument("--save_interval", type=int, default=20000)
86103
parser.add_argument(
87-
"--batch_size", type=int, default=4, help="Total batch size across all parallel ranks."
104+
"--total_steps", type=int, default=60000, help="Total number of steps for debugging."
88105
)
89106
parser.add_argument("--master_port", type=str, default="12357")
90107

91108
args = parser.parse_args()
92109
# TODO: add sanity check for args
93110

94111
def set_ranks(args):
95-
# TODO(hg): This is for TP-DDP setting only. Add "no-parallel", "MP", "FSDP".
96112
args.world_size = len(args.teacher_devices) + len(args.student_devices)
97-
args.student_ranks = list(range(len(args.student_devices)))
98-
args.teacher_ranks = list(
99-
range(len(args.student_devices), len(args.student_devices) + len(args.teacher_devices))
113+
args.teacher_ranks = list(range(len(args.teacher_devices)))
114+
args.student_ranks = list(
115+
range(len(args.teacher_devices), len(args.teacher_devices) + len(args.student_devices))
100116
)
101117

102118
set_ranks(args)

0 commit comments

Comments (0)