From 1df1bd0e7a08dc947d47cb3df9a1b43ea61bc1b7 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 16 Nov 2025 17:17:44 -0800 Subject: [PATCH 1/4] squash: new trainer with HF and SGL backend Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/eagle_utils.py | 2 +- .../scripts/ar_validate.py | 7 +- examples/speculative_decoding/train.py | 129 +++++ .../trainer/distill_trainer.py | 508 ++++++++++++++++++ .../trainer/sgl_wrapper.py | 208 +++++++ .../torch/speculative/plugins/transformers.py | 2 +- 6 files changed, 852 insertions(+), 4 deletions(-) create mode 100644 examples/speculative_decoding/train.py create mode 100644 examples/speculative_decoding/trainer/distill_trainer.py create mode 100644 examples/speculative_decoding/trainer/sgl_wrapper.py diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 4c6da77a1..b92e81190 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -498,7 +498,7 @@ def compute_loss(self, *args, **kwargs): kwargs.pop("num_items_in_batch", None) loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs) if hasattr(outputs, "train_acc"): - self.state.training_accs.append(outputs.train_acc) + self.state.training_accs.append([acc.item() for acc in outputs.train_acc]) return loss diff --git a/examples/speculative_decoding/scripts/ar_validate.py b/examples/speculative_decoding/scripts/ar_validate.py index 38b886693..8e3212ac6 100644 --- a/examples/speculative_decoding/scripts/ar_validate.py +++ b/examples/speculative_decoding/scripts/ar_validate.py @@ -26,11 +26,14 @@ mto.enable_huggingface_checkpointing() -def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None): +def validate_ar( + model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None, disable_pbar=False +): validator = 
HFARValidation(model, tokenizer) num_samples = min(num_samples, len(ds)) ars = [] - for i in tqdm(range(num_samples), desc="Validating AR"): + print("validating AR...") + for i in tqdm(range(num_samples), disable=disable_pbar): prompt = ds[i]["prompt"][0] input_ids = tokenizer(prompt, return_tensors="pt").input_ids # Apply chat template to the prompt, continuing with assistant response diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py new file mode 100644 index 000000000..3d21d0d15 --- /dev/null +++ b/examples/speculative_decoding/train.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module +from trainer.distill_trainer import EagleSGLTrainer, EagleTPTrainer +from transformers import AutoTokenizer + +torch.manual_seed(0) + + +def _setup_distributed(rank, args, backend="nccl"): + """Initialize distributed environment""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = args.master_port + os.environ["LOCAL_RANK"] = str(rank) + # Initialize process group + dist.init_process_group(backend, rank=rank, world_size=args.world_size) + if rank in args.teacher_ranks: + torch.cuda.set_device(args.teacher_devices[rank]) + else: + torch.cuda.set_device(args.student_devices[rank - len(args.teacher_ranks)]) + print( + f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}" + ) + args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks) + args.student_pgroup = dist.new_group(ranks=args.student_ranks) + + +def train(rank, args): + _setup_distributed(rank, args) + + tokenizer = AutoTokenizer.from_pretrained( + args.model_path, model_max_length=args.training_seq_len + ) + args.use_offline_training = False + args.vlm_processor = None + args.offline_data_path = None + data_module = make_eagle_supervised_data_module(tokenizer, args) + + train_dataloader = torch.utils.data.DataLoader( + data_module["train_dataset"], + batch_size=args.batch_size, + shuffle=True, + num_workers=0, + collate_fn=DataCollatorWithPadding(max_length=args.training_seq_len), + drop_last=True, + ) + trainer_cls = { + "sglang": EagleSGLTrainer, + "hf": EagleTPTrainer, + }[args.teacher_backend] + trainer = trainer_cls(rank, args, tokenizer, train_dataloader) + trainer.train() + trainer.save(args.out_path) + + +def main(): + parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example") + + # Training args 
+ parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + parser.add_argument("--data_path", type=str, required=True, help="Training dataset.") + parser.add_argument("--training_seq_len", type=str, default=1024) + parser.add_argument("--eagle_config_path", type=str, default="eagle_config.json") + parser.add_argument("--out_path", type=str, default="ckpts/fast-trained") + parser.add_argument("--lr", type=float, default=1e-5) + parser.add_argument("--epoch", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=8, help="Total bs across all ranks.") + + # Trainer args + parser.add_argument("--teacher_backend", type=str, choices=["sglang", "hf"], default="sglang") + parser.add_argument( + "--teacher_ep_size", + type=int, + default=1, + help="Teacher EP size, only used for sglang backend.", + ) + parser.add_argument("--teacher_devices", type=list, default=[0, 1, 2, 3]) + parser.add_argument("--student_devices", type=list, default=[4, 5, 6, 7]) + parser.add_argument( + "--lazy_preprocess", type=bool, default=True, help="Whether to use lazy preprocessing." + ) + parser.add_argument("--log_interval", type=int, default=50) + parser.add_argument("--save_interval", type=int, default=20000) + parser.add_argument( + "--total_steps", type=int, default=60000, help="Total number of steps for debugging." 
+ ) + parser.add_argument("--master_port", type=str, default="12357") + + args = parser.parse_args() + # TODO: add sanity check for args + + def set_ranks(args): + args.world_size = len(args.teacher_devices) + len(args.student_devices) + args.teacher_ranks = list(range(len(args.teacher_devices))) + args.student_ranks = list( + range(len(args.teacher_devices), len(args.teacher_devices) + len(args.student_devices)) + ) + + set_ranks(args) + # Launch multiple processes + mp.spawn( + train, + args=(args,), + nprocs=args.world_size, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/speculative_decoding/trainer/distill_trainer.py b/examples/speculative_decoding/trainer/distill_trainer.py new file mode 100644 index 000000000..dd45532d2 --- /dev/null +++ b/examples/speculative_decoding/trainer/distill_trainer.py @@ -0,0 +1,508 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +from abc import abstractmethod +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.optimization import get_linear_schedule_with_warmup +from transformers.utils import ModelOutput + +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG +from modelopt.torch.utils import print_rank_0 + +from .sgl_wrapper import SglangTargetModel + +try: + import wandb +except ImportError: + wandb = None + + +mto.enable_huggingface_checkpointing() + +# Shape and dtype description of the distillation signal +DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]] + + +class BaseDistillTrainer: + """ + Base distill trainer. + Designed as a replacement for the HF trainer for several purposes: + 1. Allow separate placement and parallelism for teacher and student. + 2. Allow overlapped teacher and student steps. + 3. Clean, minimal training loop to reduce compatibility issues. + Args: + rank: rank of the current process + args: arguments + teacher_step: teacher step function. + student_step: student step function. 
+ """ + + def __init__(self, rank, args, tokenizer, dataloader): + self.rank = rank + self.args = args + self.tokenizer = tokenizer + self.dataloader = dataloader + self.logs = {} + + # Prepare models + if rank in args.student_ranks: + self.model = self._prepare_student_model() + self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr) + # Same scheduler as HF trainer default + self.scheduler = get_linear_schedule_with_warmup( + self.optimizer, num_warmup_steps=0, num_training_steps=self.args.total_steps + ) + else: + self.model = self._prepare_teacher_model() + dist.barrier() + self._print_model_placement(self.model) + + def _print_model_placement(self, module): + for name, param in module.named_parameters(): + print(f"(Rank {self.rank}) {name:<60} --> {param.device}") + + def _reset_all_mem_stats(self): + torch.cuda.reset_max_memory_allocated(self.current_rank_device) + + def _print_mem_stats(self): + max_mem = torch.cuda.max_memory_allocated(self.current_rank_device) + print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB") + + @property + def current_rank_device(self): + """Return device of the current rank.""" + + @property + def distill_metadata(self): + """Return a DistillMetadata that describes the distillation message received by student.""" + + @abstractmethod + def _prepare_teacher_model(self): + """Return converted teacher model with correct parallelization.""" + + @abstractmethod + def _prepare_student_model(self): + """Return converted student model with correct parallelization.""" + + @abstractmethod + def teacher_step(self, *args, **kwargs) -> list[dict[str, torch.Tensor]]: + """Run one teacher step and return distillation messages for each student rank.""" + + @abstractmethod + def student_step(self, *args, **kwargs) -> ModelOutput: + """Run forward of student step, return a ModelOutput object.""" + + def save(self, save_path): + """Save training ckpt on first student rank.""" + if self.rank != 
self.args.student_ranks[0]: + return + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + self.model.module.save_pretrained(save_path) + else: + self.model.save_pretrained(save_path) + self.tokenizer.save_pretrained(save_path) + torch.save(self.optimizer.state_dict(), f"{save_path}/optimizer.pt") + torch.save(self.scheduler.state_dict(), f"{save_path}/scheduler.pt") + print_rank_0(f"Training ckpt saved to {save_path}") + + def _check_valid_message(self, message: dict[str, torch.Tensor]): + """Check if message in the format of distill_metadata.""" + if set(message.keys()) != set(self.distill_metadata.keys()): + raise ValueError( + f"Message keys: {set(message.keys())} \n" + f"do not match expected keys {set(self.distill_metadata.keys())}" + ) + if len(message) != len(self.distill_metadata): + raise ValueError( + f"Message length: {len(message)} \n" + f"does not match expected {len(self.distill_metadata)}" + ) + for k, v in message.items(): + if v.shape != self.distill_metadata[k][0] or v.dtype != self.distill_metadata[k][1]: + raise ValueError( + f"Invalid message. 
{k} has shape {v.shape} and dtype {v.dtype}, \n" + f"expected {self.distill_metadata[k]}" + ) + + def _init_student_recv_buffer(self): + """Init buffer for receiving distillation messages from teacher.""" + self.student_recv_buffer = { + k: torch.empty(v[0], device=self.current_rank_device, dtype=v[1]) + for k, v in self.distill_metadata.items() + } + + def _recv_from_teacher(self): + reqs = [ + dist.irecv(buffer, src=self.args.teacher_ranks[0]) + for buffer in self.student_recv_buffer.values() + ] + for req in reqs: + req.wait() + + def _clone_recv_buffer(self): + """Return a copy of received tensors for student step input.""" + return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()} + + def _send_to_student(self, teacher_outputs): + if self.rank != self.args.teacher_ranks[0]: + return + # TODO: use broadcast + assert len(teacher_outputs) == len(self.args.student_ranks), ( + f"Number of teacher outputs {len(teacher_outputs)} does not \ + match number of student ranks {len(self.args.student_ranks)}" + ) + for idx, s in enumerate(self.args.student_ranks): + self._check_valid_message(teacher_outputs[idx]) + reqs = [dist.isend(buffer, dst=s) for buffer in teacher_outputs[idx].values()] + for req in reqs: + req.wait() + + def _get_logging_context(self): + if wandb is not None and self.rank == self.args.student_ranks[0]: + return wandb.init( + config={ + "epochs": self.args.epoch, + "lr": self.args.lr, + "batch_size": self.args.batch_size, + }, + ) + return nullcontext() + + def train(self): + """Main training entrance of the composed model.""" + self._reset_all_mem_stats() + + if self.rank in self.args.student_ranks: + with self._get_logging_context() as run: + self._init_student_recv_buffer() + + # Student training loop + for epoch in range(self.args.epoch): + pbar = ( + tqdm(self.dataloader) + if self.rank == self.args.student_ranks[0] + else self.dataloader + ) + for i, batch in enumerate(pbar): + global_step = epoch * len(self.dataloader) + i + 
if global_step >= self.args.total_steps: + break + inputs = {k: v.to(self.model.device) for k, v in batch.items()} + + # Receive distill messages from teacher + self._recv_from_teacher() + + # Run forward of student step + output = self.student_step(inputs, **self._clone_recv_buffer()) + loss = output.loss + + # Run backward step + loss.backward() + self.optimizer.step() + self.scheduler.step() + + # Log and save only on student rank 0 + if self.rank != self.args.student_ranks[0]: + continue + + train_metrics = { + "loss": round(loss.item(), 3), + "lr": self.optimizer.param_groups[0]["lr"], + # Attach all float metrics + **{k: round(v, 3) for k, v in output.items() if isinstance(v, float)}, + } + if "train_acc" in output: + train_metrics.update( + { + f"train_acc_step{i}": output["train_acc"][i].item() + for i in range(len(output["train_acc"])) + } + ) + + pbar.set_description(f"Epoch {epoch} Loss {train_metrics['loss']}") + + # Add train_metrics into self.logs as a dict of lists of metrics since last log + for key, value in train_metrics.items(): + if key not in self.logs: + self.logs[key] = [] + self.logs[key].append(value) + + if global_step % self.args.log_interval == 0: + run.log( + {k: sum(v) / len(v) for k, v in self.logs.items()}, step=global_step + ) + self.logs = {} + if global_step > 0 and global_step % self.args.save_interval == 0: + self.save(f"{self.args.out_path}/epoch_{epoch}_step_{global_step}") + + else: + # Inference Loop + for epoch in range(self.args.epoch): + for i, batch in enumerate(self.dataloader): + global_step = epoch * len(self.dataloader) + i + if global_step >= self.args.total_steps: + break + inputs = {k: v.to(self.model.device) for k, v in batch.items()} + with torch.inference_mode(): + self._send_to_student(self.teacher_step(self.model, inputs)) + + self._print_mem_stats() + # Make sure all processes have finished before destroying the process group. 
+ dist.barrier() + # clean up processes + dist.destroy_process_group() + + +class EagleTPTrainer(BaseDistillTrainer): + """A subclass of BaseDistillTrainer for online eagle training, with base model TP and student DDP.""" + + def __init__(self, rank, args, tokenizer, dataloader): + # Load eagle config + args.eagle_config = EAGLE3_DEFAULT_CFG["config"] + if args.eagle_config_path: + with open(args.eagle_config_path) as f: + custom_config = json.load(f) + args.eagle_config["eagle_architecture_config"].update(custom_config) + + super().__init__(rank, args, tokenizer, dataloader) + + @property + def current_rank_device(self): + if self.rank in self.args.teacher_ranks: + return self.args.teacher_devices[self.rank] + else: + return self.args.student_devices[self.rank - len(self.args.teacher_ranks)] + + def _prepare_teacher_model(self): + # Load model with TP among teacher ranks. + model = AutoModelForCausalLM.from_pretrained( + self.args.model_path, + torch_dtype="auto", + tp_plan="auto", + device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"), + ) + # load eagle config and convert. 
+ self.args.eagle_config["eagle_architecture_config"].update( + { + "hidden_size": model.config.hidden_size, + "vocab_size": model.config.vocab_size, + "draft_vocab_size": model.config.vocab_size, + } + ) + mtsp.convert(model, [("eagle", self.args.eagle_config)]) + model.eval() + return model + + def _prepare_student_model(self): + # Load to CPU first to avoid OOM + model = AutoModelForCausalLM.from_pretrained( + self.args.model_path, torch_dtype="auto", device_map="cpu" + ) + # Hidden size and vocab size must match base model + self.args.eagle_config["eagle_architecture_config"].update( + { + "hidden_size": model.config.hidden_size, + "vocab_size": model.config.vocab_size, + "draft_vocab_size": model.config.vocab_size, + } + ) + mtsp.convert( + model, + [("eagle", self.args.eagle_config)], + ) + + # TODO:copy needed modules and del the rest + model.model._modules.pop("layers") + model.to(self.current_rank_device) + + model.train() + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.current_rank_device], + process_group=self.args.student_pgroup, + find_unused_parameters=True, + ) + return model + + @property + def distill_metadata(self) -> DistillMetadata: + """Description of the distillation signal received by student.""" + return { + "base_model_hidden_states": ( + torch.Size( + [ + int(self.args.batch_size / len(self.args.student_ranks)), + self.args.training_seq_len, + self.args.eagle_config["eagle_architecture_config"]["hidden_size"], + ] + ), + torch.bfloat16, + ), + "aux_hidden_states": ( + torch.Size( + [ + int(self.args.batch_size / len(self.args.student_ranks)), + self.args.training_seq_len, + self.args.eagle_config["eagle_architecture_config"]["hidden_size"] * 3, + ] + ), + torch.bfloat16, + ), + "base_model_logits": ( + torch.Size( + [ + int(self.args.batch_size / len(self.args.student_ranks)), + self.args.training_seq_len, + self.args.eagle_config["eagle_architecture_config"]["draft_vocab_size"], + ] + ), + torch.bfloat16, 
+ ), + } + + def teacher_step(self, model, inputs): + # Collect base model outputs. + base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward( + **inputs, + freeze_base_model=True, + past_key_values=None, + ) + + # Aux_hidden_states could be on multiple devices. Gather before cat. + aux_hidden_states = torch.cat( + [t.to(base_model_logits.device) for t in model.pop_and_gather_aux_hiddens()], dim=-1 + ) + + # Chunk the tensors for each student rank. + base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks)) + base_model_logits = base_model_logits.chunk(len(self.args.student_ranks)) + aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks)) + + return [ + { + "base_model_hidden_states": base_model_hidden_states[i], + "aux_hidden_states": aux_hidden_states[i], + "base_model_logits": base_model_logits[i], + } + for i in range(len(self.args.student_ranks)) + ] + + def student_step( + self, + inputs, + **distill_msgs, + ) -> ModelOutput: + self.optimizer.zero_grad() + + # Chunk input_ids and attention_mask for each student rank. + student_idx = self.rank - len(self.args.teacher_ranks) + inputs = {k: v.chunk(len(self.args.student_ranks))[student_idx] for k, v in inputs.items()} + + # Second stage forward with provided base model outputs. 
+ output = self.model(**inputs, base_model_outputs=distill_msgs) + + return output + + +class EagleSGLTrainer(EagleTPTrainer): + """A subclass of EagleTPTrainer for online eagle training, with base model SGL and student DDP.""" + + def _prepare_teacher_model(self): + args = self.args + args.tp_size = len(self.args.teacher_devices) + args.max_length = self.args.training_seq_len + + teacher_config = AutoConfig.from_pretrained(args.model_path) + self.args.eagle_config["eagle_architecture_config"].update( + { + "hidden_size": teacher_config.hidden_size, + "vocab_size": teacher_config.vocab_size, + "draft_vocab_size": teacher_config.vocab_size, + } + ) + + # patch torch.distributed functions to use only partial ranks + original_get_world_size = torch.distributed.get_world_size + original_barrier = torch.distributed.barrier + torch.distributed.get_world_size = lambda *args, **kwargs: len(self.args.teacher_devices) + + def barrier_patch(*args, **kwargs): + if not args and not kwargs: + original_barrier(group=self.args.teacher_pgroup) + else: + original_barrier(*args, **kwargs) + + torch.distributed.barrier = barrier_patch + + # load SGL model with patches + model = SglangTargetModel( + args=args, + tp_group=self.args.teacher_pgroup, + return_full_logits=True, + gpu_id=self.current_rank_device, + ) + + # restore patches + torch.distributed.get_world_size = original_get_world_size + torch.distributed.barrier = original_barrier + + model.set_aux_hidden_states_layers() + print("rank", self.rank, "SGL base model loaded") + model.device = self.current_rank_device + return model + + def teacher_step(self, model, inputs): + # TODO: handle data loading in preprocess + sgl_inputs = [ + { + "input_ids": inputs["input_ids"][i], + "attention_mask": inputs["attention_mask"][i], + "loss_mask": inputs["loss_mask"][i], + } + for i in range(inputs["input_ids"].shape[0]) + ] + + logits, h, aux_h = model.forward(sgl_inputs) + + # Chunk the tensors for each student rank. 
+ base_model_logits = logits.chunk(len(self.args.student_ranks)) + base_model_hidden_states = h.chunk(len(self.args.student_ranks)) + aux_hidden_states = aux_h.chunk(len(self.args.student_ranks)) + + seq_len = inputs["input_ids"].shape[1] + vocab_size = logits.shape[-1] + hid_size = h.shape[-1] + + return [ + { + "base_model_hidden_states": base_model_hidden_states[i].view(-1, seq_len, hid_size), + "aux_hidden_states": aux_hidden_states[i].view(-1, seq_len, hid_size * 3), + "base_model_logits": base_model_logits[i] + .view(-1, seq_len, vocab_size) + .to(dtype=torch.bfloat16), + } + for i in range(len(self.args.student_ranks)) + ] diff --git a/examples/speculative_decoding/trainer/sgl_wrapper.py b/examples/speculative_decoding/trainer/sgl_wrapper.py new file mode 100644 index 000000000..cd8cc7dcb --- /dev/null +++ b/examples/speculative_decoding/trainer/sgl_wrapper.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file contains the wrapper for the SGL model. 
+""" + +import torch +import torch.distributed as dist +import torch.nn as nn +from sglang.bench_one_batch import BenchArgs, _maybe_prepare_mlp_sync_batch, load_model +from sglang.srt.entrypoints.engine import _set_envs_and_config +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import configure_logger +from transformers import AutoConfig + +torch.manual_seed(42) + + +class LogitsProcessorForEAGLE3(torch.nn.Module): + def __init__(self, logits_processor: LogitsProcessor, return_full_logits: bool = False): + super().__init__() + self.logits_processor = logits_processor + self.return_full_logits = return_full_logits + + def forward( + self, + input_ids, + hidden_states, + lm_head, + logits_metadata, + aux_hidden_states: torch.Tensor | None = None, + ) -> LogitsProcessorOutput: + if self.return_full_logits: + logits_metadata.forward_mode = ForwardMode.DECODE + ret = self.logits_processor.forward( + input_ids, hidden_states, lm_head, logits_metadata, aux_hidden_states + ) + ret.last_hidden_states = hidden_states + return ret + + +def wrap_logits_processors_in_module(module: nn.Module, return_full_logits: bool = False): + for name, submodule in module.named_modules(): + if isinstance(submodule, LogitsProcessor): + wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits) + setattr(module, name, wrapped) + print(f"wrapped {name} with LogitsProcessorForEAGLE3") + + +class SglangTargetModel(nn.Module): + def __init__( + self, + args, + tp_group, + gpu_id, + return_full_logits=False, + 
): + super().__init__() + self.return_full_logits = return_full_logits + tp_rank = dist.get_rank(group=tp_group) if dist.is_initialized() else 0 + self.tp_group = tp_group + self.tp_rank = tp_rank + self.args = args + self.bench_args = BenchArgs() + self.server_args = ServerArgs(model_path=args.model_path) + self.server_args.enable_return_hidden_states = True + self.server_args.context_length = args.max_length + self.server_args.tp_size = args.tp_size + self.server_args.ep_size = args.teacher_ep_size + + self.server_args.cuda_graph_max_bs = max(self.bench_args.batch_size) + self.server_args.cuda_graph_bs = list(self.bench_args.batch_size) + _set_envs_and_config(self.server_args) + self.port_args = PortArgs.init_new(self.server_args) + configure_logger(self.server_args, prefix=f" TP{tp_rank}") + self.model_runner, _ = load_model(self.server_args, self.port_args, gpu_id, tp_rank) + wrap_logits_processors_in_module(self.model_runner.model, return_full_logits) + + def set_aux_hidden_states_layers(self, aux_hidden_states_layers=None): + config = AutoConfig.from_pretrained( + self.server_args.model_path, + trust_remote_code=self.server_args.trust_remote_code, + ) + if aux_hidden_states_layers is None: + if hasattr(config, "num_hidden_layers"): + num_layers = config.num_hidden_layers + elif hasattr(config, "text_config"): + num_layers = config.text_config.num_hidden_layers + else: + raise ValueError( + f"config {config} does not have num_hidden_layers or text_config.num_hidden_layers" + ) + # in sglang, when we do set_eagle3_layers_to_capture, we will add 1 to the layer index + aux_hidden_states_layers = [ + 2 - 1, + num_layers // 2 - 1, + num_layers - 3 - 1, + ] + self.aux_hidden_states_layers = aux_hidden_states_layers + assert len(self.aux_hidden_states_layers) == 3, ( + "aux_hidden_states_layers is expected to be 3 layers" + ) + print(f"Capturing Aux hidden states layers: {self.aux_hidden_states_layers}") + + if not hasattr(self.model_runner.model, 
"set_eagle3_layers_to_capture"): + raise ValueError( + f"model_runner.model {self.model_runner.model} does not have set_eagle3_layers_to_capture" + ) + self.model_runner.model.set_eagle3_layers_to_capture(self.aux_hidden_states_layers) + if hasattr(self.model_runner.model, "capture_aux_hidden_states"): + assert self.model_runner.model.capture_aux_hidden_states, ( + "model_runner.model.capture_aux_hidden_states is expected to be True" + ) + elif hasattr(self.model_runner.model.language_model, "capture_aux_hidden_states"): + assert self.model_runner.model.language_model.capture_aux_hidden_states, ( + "model_runner.model.capture_aux_hidden_states is expected to be True" + ) + else: + raise ValueError( + f"model_runner.model {self.model_runner.model} does not have capture_aux_hidden_states" + ) + + @torch.no_grad + def extend(self, reqs: list[Req]): + tree_cache = RadixCache( + None, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + page_size=self.model_runner.server_args.page_size, + ) + batch = ScheduleBatch.init_new( + reqs=reqs, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + tree_cache=tree_cache, + model_config=self.model_runner.model_config, + enable_overlap=False, + spec_algorithm=SpeculativeAlgorithm.NONE, + ) + batch.prepare_for_extend() + _maybe_prepare_mlp_sync_batch(batch, self.model_runner) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL + logits_output, _ = self.model_runner.forward(forward_batch) + input_lens = [len(req.origin_input_ids) for req in reqs] + assert logits_output.last_hidden_states.shape[0] == sum(input_lens), ( + "the number of hidden states is not correct" + ) + assert logits_output.hidden_states.shape[0] == sum(input_lens), ( + "the number of hidden states is not correct" + ) + 
self.model_runner.req_to_token_pool.clear() + self.model_runner.token_to_kv_pool_allocator.clear() + return ( + logits_output.next_token_logits, + logits_output.last_hidden_states, + logits_output.hidden_states, + ) + + def forward( + self, + data_for_target: list[dict[str, torch.Tensor]], + ): + """ + arguments: + data_for_target: List[Dict[str, torch.Tensor]] of target_batch_size + - input_ids: (tp_size, seq_len) + - attention_mask: (tp_size, seq_len) + - loss_mask: (tp_size, seq_len) + """ + sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1) + reqs = [] + for idx, data in enumerate(data_for_target): + req = Req( + rid=str(idx), + origin_input_text="", + origin_input_ids=data["input_ids"].view(-1).tolist(), + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + req.logprob_start_len = len(req.origin_input_ids) - 1 + reqs.append(req) + logits, hidden_states_list, aux_hidden_states_list = self.extend(reqs) + return logits, hidden_states_list, aux_hidden_states_list diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 1aed13e87..845eb4974 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -916,7 +916,7 @@ def _eagle_loss( valid = loss_mask[:, :, 0].bool() correct = (base_predict_tok == eagle_predict_tok) & valid denom = valid.sum().clamp_min(1).float() - accuracy = round(correct.sum().float().div(denom).item(), 3) + accuracy = correct.sum().float().div(denom) return classification_loss, accuracy From 6df419f005ded17fd13a5b44bf76468382ae84f9 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 16 Nov 2025 17:43:17 -0800 Subject: [PATCH 2/4] add license Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../trainer/sgl_wrapper.py | 22 
+++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/examples/speculative_decoding/trainer/sgl_wrapper.py b/examples/speculative_decoding/trainer/sgl_wrapper.py index cd8cc7dcb..8cb2b2319 100644 --- a/examples/speculative_decoding/trainer/sgl_wrapper.py +++ b/examples/speculative_decoding/trainer/sgl_wrapper.py @@ -13,6 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +# MIT License + +# Copyright (c) 2025 sgl-project + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + """ This file contains the wrapper for the SGL model. 
""" From a22a9482d0b6bb0bb6ab7b871aa7f0dd655e0069 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:48:26 -0800 Subject: [PATCH 3/4] debug:sgl backend; use torchrun Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/train.py | 61 ++++++++++++-------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py index 3d21d0d15..f4aa8621c 100644 --- a/examples/speculative_decoding/train.py +++ b/examples/speculative_decoding/train.py @@ -14,11 +14,9 @@ # limitations under the License. import argparse -import os import torch import torch.distributed as dist -import torch.multiprocessing as mp from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module from trainer.distill_trainer import EagleSGLTrainer, EagleTPTrainer from transformers import AutoTokenizer @@ -26,27 +24,31 @@ torch.manual_seed(0) -def _setup_distributed(rank, args, backend="nccl"): - """Initialize distributed environment""" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = args.master_port - os.environ["LOCAL_RANK"] = str(rank) - # Initialize process group - dist.init_process_group(backend, rank=rank, world_size=args.world_size) +def _check_args(args): + """Sanity check for arguments.""" + # TODO: (hg) + + +def _setup_pgroups(args): + """Initialize student/teacher pgroups and set devices.""" + rank = dist.get_rank() + args.teacher_ranks = list(range(len(args.teacher_devices))) + args.student_ranks = list( + range(len(args.teacher_devices), len(args.teacher_devices) + len(args.student_devices)) + ) if rank in args.teacher_ranks: torch.cuda.set_device(args.teacher_devices[rank]) else: torch.cuda.set_device(args.student_devices[rank - len(args.teacher_ranks)]) print( - f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}" 
+ f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={dist.get_world_size()}" ) args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks) args.student_pgroup = dist.new_group(ranks=args.student_ranks) -def train(rank, args): - _setup_distributed(rank, args) - +def train(args): + """Entrance for training.""" tokenizer = AutoTokenizer.from_pretrained( args.model_path, model_max_length=args.training_seq_len ) @@ -55,6 +57,10 @@ def train(rank, args): args.offline_data_path = None data_module = make_eagle_supervised_data_module(tokenizer, args) + # Ensure different ranks load the same data + g = torch.Generator() + g.manual_seed(0) + train_dataloader = torch.utils.data.DataLoader( data_module["train_dataset"], batch_size=args.batch_size, @@ -62,12 +68,13 @@ def train(rank, args): num_workers=0, collate_fn=DataCollatorWithPadding(max_length=args.training_seq_len), drop_last=True, + generator=g, ) trainer_cls = { "sglang": EagleSGLTrainer, "hf": EagleTPTrainer, }[args.teacher_backend] - trainer = trainer_cls(rank, args, tokenizer, train_dataloader) + trainer = trainer_cls(dist.get_rank(), args, tokenizer, train_dataloader) trainer.train() trainer.save(args.out_path) @@ -76,7 +83,7 @@ def main(): parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example") # Training args - parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + parser.add_argument("--model_path", type=str, required=True, help="Target model path.") parser.add_argument("--data_path", type=str, required=True, help="Training dataset.") parser.add_argument("--training_seq_len", type=str, default=1024) parser.add_argument("--eagle_config_path", type=str, default="eagle_config.json") @@ -103,26 +110,16 @@ def main(): parser.add_argument( "--total_steps", type=int, default=60000, help="Total number of steps for debugging." 
) + parser.add_argument("--master_addr", type=str, default="localhost") parser.add_argument("--master_port", type=str, default="12357") args = parser.parse_args() - # TODO: add sanity check for args - - def set_ranks(args): - args.world_size = len(args.teacher_devices) + len(args.student_devices) - args.teacher_ranks = list(range(len(args.teacher_devices))) - args.student_ranks = list( - range(len(args.teacher_devices), len(args.teacher_devices) + len(args.student_devices)) - ) - - set_ranks(args) - # Launch multiple processes - mp.spawn( - train, - args=(args,), - nprocs=args.world_size, - join=True, - ) + + dist.init_process_group("nccl") + + _check_args(args) + _setup_pgroups(args) + train(args) if __name__ == "__main__": From 70515dc8e8fe00fa4cf0edbc3db744284e675010 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 21 Nov 2025 01:07:22 +0000 Subject: [PATCH 4/4] minor fix Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- modelopt/torch/speculative/plugins/transformers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 845eb4974..6d95ae48b 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -663,10 +663,10 @@ def _base_model_forward( self, input_ids, attention_mask, - position_ids, - past_key_values, - freeze_base_model, - labels, + position_ids=None, + past_key_values=None, + freeze_base_model=True, + labels=None, **kwargs, ): # TODO: This function still use eagle_module. Ideally we should remove it,