diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000..0c5e58c --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,111 @@ +import torch +from transformers import LlamaConfig + +from speculators.train.eagle3.core import Eagle3DraftModel, Eagle3VerifierLMHead +from speculators.train.data import Eagle3SampleFileDataset, create_collate_fn +from speculators.train.distributed_batch_sampler import ( + MultipackDistributedBatchSamplerV2, +) +from torch.utils.data import DataLoader + +from speculators.train.utils import maybe_setup_distributed, maybe_destroy_distributed +from speculators.train.trainer import Trainer +from speculators.train.logger import setup_metric_logger, setup_root_logger + + +local_rank, world_size, rank, is_distributed = maybe_setup_distributed() + + +DEVICE = torch.device(local_rank) +EPOCHS = 10 +draft_vocab_size = 5000 +total_seq_len = 5120 +datapath = "./data" +verifier_model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct" + + +# TEMP MODEL SETUP +llama_config = LlamaConfig.from_pretrained(verifier_model_name_or_path) +llama_config._attn_implementation = "simple_flex_attention" +hidden_size = llama_config.hidden_size +verifier_vocab_size = llama_config.vocab_size + +# d2t_vocab = torch.zeros(draft_vocab_size, dtype=torch.long).to(DEVICE) +# t2d_vocab = ( +# torch.cat( +# [ +# torch.ones(draft_vocab_size), +# torch.zeros(llama_config.vocab_size - draft_vocab_size), +# ] +# ) +# .to(torch.bool) +# .to(DEVICE) +# ) +d2t_vocab = torch.load("d2t.npy").to(DEVICE) +t2d_vocab = torch.load("t2d.npy").to(DEVICE) + +setup_metric_logger(loggers="trackio", run_name=None, output_dir="./logs") +setup_root_logger() +# END TEMP MODEL SETUP + +draft_model = Eagle3DraftModel( + verifier_model_name_or_path=verifier_model_name_or_path, + hidden_size=hidden_size, + t2d_vocab=t2d_vocab, + d2t_vocab=d2t_vocab, + decoder_layer_config=llama_config, + verifier_vocab_size=verifier_vocab_size, + verifier_pad_token_id=None, + num_layers=1, + ttt_steps=3, +) + +verifier_lm_head = Eagle3VerifierLMHead( + hidden_size=hidden_size, draft_vocab_size=draft_vocab_size +) +verifier_lm_head.load_verifier_lm_head(verifier_model_name_or_path, t2d_vocab) + +dataset = Eagle3SampleFileDataset(datapath=datapath, max_len=total_seq_len) +batch_sampler = MultipackDistributedBatchSamplerV2( + batch_max_length=total_seq_len, + lengths=dataset.approx_lengths(), + num_replicas=world_size, + rank=local_rank, +) +train_loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=16, + pin_memory=True, + collate_fn=create_collate_fn(total_seq_len), +) + + +# todo: make config better +config = { + "num_epochs": EPOCHS, + "save_path": "./checkpoints", + "lr": 1e-5, + "total_seq_len": total_seq_len, + "datapath": "./data", + "resume_from_checkpoint": True, +} + + +trainer = Trainer( + draft_model, + verifier_lm_head, + config, + train_loader, + None, + is_distributed, + local_rank, + world_size, +) +trainer.run_training() + +maybe_destroy_distributed() + + +# RUN WITH: +# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 scripts/train.py diff --git a/src/speculators/train/checkpointer.py b/src/speculators/train/checkpointer.py new file mode 100644 index 0000000..c661d5b --- /dev/null +++ b/src/speculators/train/checkpointer.py @@ -0,0 +1,149 @@ +from abc import abstractmethod +import os +from pathlib import Path +import torch +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + get_optimizer_state_dict, + set_model_state_dict, + 
set_optimizer_state_dict, + StateDictOptions, +) + +import torch.distributed as dist + + +class BaseCheckpointer: + """Helper class to save and load checkpoints. + + Checkpoint file structure: + ../path/ + 0/ # epoch number + model_state_dict.pt + optimizer_state_dict.pt + 1/ + model_state_dict.pt + optimizer_state_dict.pt + ... + """ + + def __init__(self, path: Path | str, try_load_last_checkpoint: bool = True): + self.path = Path(path) + if try_load_last_checkpoint: + self.previous_epoch: int = self._get_previous_epoch() + else: + self.previous_epoch: int = -1 + + @abstractmethod + def load_model_state_dict(self, model: torch.nn.Module): + raise NotImplementedError + + @abstractmethod + def load_optimizer_state_dict( + self, model: torch.nn.Module, optimizer: torch.optim.Optimizer + ): + raise NotImplementedError + + @abstractmethod + def save_checkpoint( + self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int + ): + raise NotImplementedError + + def _get_previous_epoch(self) -> int: + if not self.path.exists(): + return -1 + last_checkpoint_num = -1 + for d in self.path.iterdir(): + if d.is_dir(): + try: + last_checkpoint_num = max(last_checkpoint_num, int(d.name)) + except ValueError: + continue + return last_checkpoint_num + + def model_path(self, epoch: int): + model_fname = "model_state_dict.pt" + return self.path / str(epoch) / model_fname + + def optimizer_path(self, epoch: int): + optimizer_fname = "optimizer_state_dict.pt" + return self.path / str(epoch) / optimizer_fname + + +class SingleGPUCheckpointer(BaseCheckpointer): + def load_model_state_dict(self, model: torch.nn.Module): + full_state_dict = torch.load( + self.model_path(self.previous_epoch), + weights_only=True, + map_location="cuda:0", # todo: make this configurable + ) + model.load_state_dict(full_state_dict) + + def load_optimizer_state_dict( + self, model: torch.nn.Module, optimizer: torch.optim.Optimizer + ): + full_state_dict = torch.load( + self.optimizer_path(self.previous_epoch), + weights_only=True, + map_location="cuda:0", # todo: make this configurable + ) + optimizer.load_state_dict(full_state_dict) + + def save_checkpoint( + self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int + ): + os.makedirs(self.path / str(epoch), exist_ok=True) + torch.save(model.state_dict(), self.model_path(epoch)) + torch.save(optimizer.state_dict(), self.optimizer_path(epoch)) + + +class DistributedCheckpointer(BaseCheckpointer): + def load_model_state_dict(self, model: torch.nn.Module): + full_state_dict = torch.load( + self.model_path(self.previous_epoch), + mmap=True, + weights_only=True, + map_location="cpu", + ) + set_model_state_dict( + model, + full_state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + dist.barrier() + + def load_optimizer_state_dict(self, model, optimizer: torch.optim.Optimizer): + full_state_dict = torch.load( + self.optimizer_path(self.previous_epoch), + mmap=True, + weights_only=True, + map_location="cpu", + ) + set_optimizer_state_dict( + model, + optimizer, + full_state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + dist.barrier() + + def save_checkpoint( + self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int + ): + model_state_dict = get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + optimizer_state_dict = get_optimizer_state_dict( + model, + optimizer, + options=StateDictOptions(full_state_dict=True, 
cpu_offload=True), + ) + + if dist.get_rank() == 0: + # Only rank 0 saves the checkpoint + os.makedirs(self.path / str(epoch), exist_ok=True) + torch.save(model_state_dict, self.model_path(epoch)) + torch.save(optimizer_state_dict, self.optimizer_path(epoch)) + + dist.barrier() diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py new file mode 100644 index 0000000..81cab01 --- /dev/null +++ b/src/speculators/train/data.py @@ -0,0 +1,188 @@ +from functools import lru_cache +import math +import os +from typing import Any + +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F + +BatchType = dict[str, Any] + + +class AddGaussianNoise: + def __init__(self, mean=0.0, std=0.0): + self.mean = mean + self.std = std + + def __call__(self, data): + tensor = data["hidden_states"] + noise = torch.randn(tensor.size()) * self.std + self.mean + noisy_tensor = tensor + noise + data["hidden_states"] = noisy_tensor + return data + + +class AddUniformNoise: + def __init__(self, std=0.0): + self.std = std + + def __call__(self, data): + tensor = data["hidden_states"] + noise = (torch.rand_like(tensor) - 0.5) * self.std * 512 / tensor.shape[1] + noisy_tensor = tensor + noise + data["hidden_states"] = noisy_tensor + return data + + +def list_files(path): + datapath = [] + for root, _directories, files in os.walk(path): + for file in files: + file_path = os.path.join(root, file) + datapath.append(file_path) + + return datapath + + +def slice_and_pad_to_length(tensor, length): + sliced_tensor = tensor[:length] + padding = [0, 0] * sliced_tensor.dim() + padding[-1] = length - sliced_tensor.shape[0] + return F.pad(sliced_tensor, padding) + +def shift_batch(batch: BatchType): + input_ids = batch["input_ids"] # shape: [seq_len] + # [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9] + hidden_states = batch["hidden_states"] # shape: [seq_len, hidden_size] + # [g0, g1, g2, g3, g4, g5, g6, g7, g8, g9] + verifier_last_hidden_states = batch["verifier_last_hidden_states"] # shape: [seq_len, hidden_size] + # [y0, y1, y2, y3, y4, y5, y6, y7, y8, y9] + loss_mask = batch["loss_mask"] # shape: [seq_len] + # [l0, l1, l2, l3, l4, l5, l6, l7, l8, l9] + lengths = batch["lengths"] # shape: [1] + # [10] + position_ids = batch["position_ids"] # shape: [seq_len] + # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + # Need to align (x1, g0, y1, l1) + # todo: verify loss mask shift is correct + + # Drop x0, g(-1), y0, l0, reduce seq_len by 1 + + input_ids = input_ids[1:] + hidden_states = hidden_states[:-1] + verifier_last_hidden_states = verifier_last_hidden_states[1:] + loss_mask = loss_mask[1:] + lengths = lengths - 1 + position_ids = position_ids[1:] # Note: position_ids now start at 1 + + return { + "input_ids": input_ids, + "hidden_states": hidden_states, + "verifier_last_hidden_states": verifier_last_hidden_states, + "loss_mask": loss_mask, + "lengths": lengths, + "position_ids": position_ids, + } + +class Eagle3SampleFileDataset(Dataset): + def __init__( + self, + datapath: str, + max_len: int, + transform=None, + hidden_states_dtype=torch.float, + ): + self.data = list_files(datapath) + self.max_len = max_len + self.transform = transform + self.hidden_states_dtype = hidden_states_dtype + + def __len__(self): + return len(self.data) + + @lru_cache(maxsize=1) + def approx_lengths(self): + lengths_0 = self.__getitem__(0)["lengths"] + # this is a single sample so there is only one length + lengths_0 = lengths_0[0].item() + size_0 = os.path.getsize(self.data[0]) + + approx_lengths = [ + 
math.ceil(os.path.getsize(fname) / size_0 * lengths_0) + for fname in self.data + ] + return approx_lengths + + def __getitem__(self, index) -> BatchType: + data = torch.load(self.data[index]) + + # todo: standardize names during data generation and then remove this + data["hidden_states"] = data["hidden_state"] + data["verifier_last_hidden_states"] = data["target"] + del data["hidden_state"] + del data["target"] + + # todo: standardize dtypes during data generation and then remove this + data = { + k: v.to(self.hidden_states_dtype) if "hidden_states" in k else v + for k, v in data.items() + } + + seq_len = data["input_ids"].shape[0] + # Add lengths tensor + data["lengths"] = torch.tensor([seq_len], dtype=torch.long) + + if self.transform: + data = self.transform(data) + + + data["position_ids"] = torch.arange(seq_len, dtype=torch.long) + # shape: [seq_len] + + # data structure: { + # "hidden_states": [seq_len, 3 * hidden_size], + # "input_ids": [seq_len], + # "verifier_last_hidden_states": [seq_len, hidden_size], + # "loss_mask": [seq_len], + # "lengths": [1], + # "position_ids": [seq_len], + # } + + # Note: shift_batch will reduce seq_len by 1 + data = shift_batch(data) + + return data + + +def create_collate_fn(max_len: int): + def collate_fn(batch: list[BatchType]) -> BatchType: + collated_data = {} + for key in batch[0].keys(): + collated_data[key] = torch.cat([b[key] for b in batch], dim=0) + + if key != "lengths": + collated_data[key] = slice_and_pad_to_length( + collated_data[key], max_len + ).unsqueeze(0) + # shape: [1, max_len, ...] + + # Handle lengths update + lengths = collated_data["lengths"] + new_lengths = [] + cum_length = 0 + for length in lengths: + if length + cum_length >= max_len: + new_lengths.append(max_len - cum_length) + cum_length = max_len + break + new_lengths.append(length) + cum_length += length + if cum_length < max_len: + # Add extra "padded" sample so that sum(new_lengths) == max_len + new_lengths.append(max_len - cum_length) + collated_data["lengths"] = torch.tensor(new_lengths, dtype=torch.long) + return collated_data + + return collate_fn diff --git a/src/speculators/train/distributed_batch_sampler.py b/src/speculators/train/distributed_batch_sampler.py new file mode 100644 index 0000000..58075a0 --- /dev/null +++ b/src/speculators/train/distributed_batch_sampler.py @@ -0,0 +1,210 @@ +""" +MIT License + +Copyright (c) 2023 One + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Adapted from https://github.com/imoneoi/multipack_sampler. 
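+ +This module packs variable-length samples into per-rank batches under a fixed token budget (batch_max_length) using an LPT (longest-processing-time-first) heuristic.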
+""" + +# Standard +from functools import lru_cache +from heapq import heapreplace +from typing import NamedTuple +import warnings + +# Third Party +from numpy.typing import ArrayLike, NDArray +from torch.utils.data import Sampler +import numpy as np + + +## Multipack Distributed Batch Sampler +class _Bin(NamedTuple): + """Helper named tuple for `lpt_packed_batch`""" + + fill: int # sum of items in _Bin + rank: int # device rank _Bin is associated with + + +def _lpt_packed_batch( + lengths: np.ndarray, max_len: int, num_replicas: int, start_index: int, rank: int +) -> None | list: + """ + Check if lengths can be distributed into `num_replicas` machines with at most `max_len` tokens per machine and return local rank's batch. + + Uses the LPT (Longest processing time first scheduling) algorithm + Time: O(|lengths| log |lengths| + |lengths| log replicas) + + Returns: + `None` if unable to find a valid packing. Otherwise return the batch indices that correspond to `rank`. + """ + + # Greedily assign lengths (in decreasing order) to the least full rank until they are all assigned or + # we run out of space. + local_batch = [] + heap = [_Bin(0, i) for i in range(num_replicas)] + + # sort in descending order + indices = np.argsort(lengths)[::-1] + + for idx, size in zip(indices, lengths[indices]): + new_fill = heap[0].fill + size + if new_fill > max_len: + # Size doesn't fit in least full batch (or any others), can't satisfy requirements + return None + + if heap[0].rank == rank: + # minimum bucket corresponds to the local rank -> add idx to local batch + local_batch.append(start_index + idx) + + _ = heapreplace(heap, _Bin(new_fill, heap[0].rank)) + + return local_batch + + +def _assign_to_packed_batches( + lengths: np.ndarray, max_len: int, rank: int, replicas: int +) -> list[NDArray]: + """Distribute lengths to batches across all ranks, while respecting batch_max_length. 
Uses a binary search + LPT algorithm + + Args: + lengths (np.ndarray): array of dataset sample lengths + max_len (int): maximum allowed sum of lengths in batch + rank (int): local rank to collect batches for + replicas (int): world size to distribute batches to + + Returns: + tuple[list, int, int]: + - list of np.arrays containing the indices for each batch on the local rank + - sum of dataset lengths included (total sum of lengths in dataset minus any that were dropped at end of dataset) + - total token capacity if each batch maxed out batch_max_length + """ + + lengths_so_far = 0 + ind = 0 + result = [] + lengths_cumsum = np.cumsum(lengths) + + # binary search for max integer x such that the next x elements in shuffled lengths array can be packed into `num_replicas` batches + # Add the local rank's batch to `result` and repeat until end of dataset + while True: + if len(lengths) - ind < replicas: + # Not enough lengths left to pack into `num_replicas` batches + # Break and drop whatever lengths we have left + break + + # binary search in [1, 1 + upper bound for x) + left = 1 + right = 1 + np.searchsorted( + lengths_cumsum[ind:], lengths_so_far + max_len * replicas, "right" + ) + + batch = None + while right - left > 1 and right > replicas: + mid = (left + right) // 2 + batch = _lpt_packed_batch( + lengths[ind : ind + mid], max_len, replicas, ind, rank + ) + if batch is None: + right = mid + else: + left = mid + + if batch is None: + batch = _lpt_packed_batch( + lengths[ind : ind + left], max_len, replicas, ind, rank + ) + + ind += left + lengths_so_far = lengths_cumsum[ind - 1] + + # append only result for local rank (already filtered in lpt_packed_batch) + result.append(batch) + + return result + + +class MultipackDistributedBatchSamplerV2(Sampler): + def __init__( + self, + batch_max_length: int, + lengths: ArrayLike, + num_replicas: int, + rank: int, + seed: int = 0, + ): + """Efficient distributed packing sampler for linear attention style models + + Args: + batch_max_length (int): max number of tokens in a single batch per device + lengths (ArrayLike[int]): the lengths of each sample in the dataset + num_replicas (int): The number of replicas to split the dataset across. + rank (int): The local rank to collect batches for. + seed (int, optional): Seed for RNG, must be the same on all ranks. Defaults to 0. + """ + self.num_replicas = num_replicas + self.rank = rank + self.seed = seed + self.epoch = 0 + self.batch_max_length = batch_max_length + self.lengths = np.array(lengths) + + self.valid_indices = np.nonzero(self.lengths <= self.batch_max_length)[0] + if self.rank == 0 and len(self.valid_indices) < len(self.lengths): + msg = ( + f"Dropping {len(self.lengths) - len(self.valid_indices)}" + f"/{len(self.lengths)} samples longer than batch_max_length. " + "Ensure that the right max_batch_length is used during data processing." 
+ ) + warnings.warn(msg) + + def __iter__(self): + batches = self._generate_batches(self.epoch) + return iter(batches) + + def __len__(self): + batches = self._generate_batches(self.epoch) + return len(batches) + + def set_epoch(self, epoch: int): + self.epoch = epoch + + @lru_cache(maxsize=1) + def _generate_batches(self, epoch: int) -> list[NDArray]: + """Generate batches for local rank + + Returns: + list[NDArray]: list of np.arrays containing the indices for each batch on the local rank + """ + + rng = np.random.default_rng(seed=self.seed + epoch) + indices = rng.permutation(self.valid_indices) + + batches = _assign_to_packed_batches( + self.lengths[indices], self.batch_max_length, self.rank, self.num_replicas + ) + + # The indices in batches are relative to the shuffled self.lengths[indices] + # Translate them so that they are instead relative to the overall unshuffled self.lengths array + batches = [indices[batch] for batch in batches] + + # Cache result + return batches diff --git a/src/speculators/train/eagle3/__init__.py b/src/speculators/train/eagle3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/speculators/train/eagle3/attention.py b/src/speculators/train/eagle3/attention.py new file mode 100644 index 0000000..c114f7f --- /dev/null +++ b/src/speculators/train/eagle3/attention.py @@ -0,0 +1,164 @@ +import torch +from transformers.modeling_utils import AttentionInterface +from transformers.integrations.flex_attention import repeat_kv +from torch.nn.attention.flex_attention import flex_attention +from torch.nn.attention.flex_attention import or_masks, and_masks, BlockMask +from typing import Callable + + +def create_combined_mask_mod(lengths: torch.Tensor): + total_seq_len = lengths.sum().item() + document_ids = torch.repeat_interleave( + torch.arange(lengths.shape[0], device=lengths.device, dtype=torch.long), lengths + ).contiguous() + N = document_ids.shape[0] + + def causal_mask_mod(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + def document_mask_mod(b, h, q_idx, kv_idx): + return document_ids[q_idx] == document_ids[kv_idx % N] + + def diagonal_draft_mask_mod(b, h, q_idx, kv_idx): + return kv_idx % total_seq_len == q_idx + + return or_masks( + and_masks(causal_mask_mod, document_mask_mod), diagonal_draft_mask_mod + ) + + +def extend_mask_for_draft_tokens(block_mask): + """ + Extend the block mask to include new draft tokens. Concatenates a diagonal mask for the new draft tokens. + + Assumptions: + - block_mask BLOCK_SIZE := KV_BLOCK_SIZE == Q_BLOCK_SIZE + - The number of query values is the original total_seq_len (or equivalently the number of query blocks is the original total_seq_len // BLOCK_SIZE) + + i.e. 
if block_mask is: + [ + [ + [1, 0, 0], + [1, 1, 0], + [0, 0, 1], + ] + ] + the result will be: + [ + [ + [1, 0, 0, 1, 0, 0], + [1, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1], + ] + ] + and then calling it again will give: + [ + [ + [1, 0, 0, 1, 0, 0, 1, 0, 0], + [1, 1, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1], + ] + ] + + """ + kv_num_blocks = block_mask.kv_num_blocks + # shape: [B, H, Q_LEN // BLOCK_SIZE] + + kv_indices = block_mask.kv_indices + # shape: [B, H, Q_LEN // BLOCK_SIZE, KV_LEN // BLOCK_SIZE] + b, h, q_blocks, kv_blocks = kv_indices.shape + + # extend kv indices if needed + kv_indices = torch.cat( + [kv_indices, kv_indices.new_zeros((b, h, q_blocks, q_blocks))], dim=-1 + ) + new_block_indices = torch.arange( + kv_blocks, + kv_blocks + q_blocks, + dtype=kv_indices.dtype, + device=kv_indices.device, + ).reshape(1, 1, q_blocks, 1) + kv_indices.scatter_( + dim=-1, index=kv_num_blocks.unsqueeze(-1), src=new_block_indices + ) + + kv_num_blocks = kv_num_blocks + 1 + + return BlockMask.from_kv_blocks( + kv_num_blocks, + kv_indices, + block_mask.full_kv_num_blocks, + block_mask.full_kv_indices, + mask_mod=block_mask.mask_mod, + ) + + +def block_mask_to_dense_attention_mask( + block_mask: BlockMask, device: torch.device, dtype: torch.dtype +): + attention_mask = torch.ones(block_mask.shape, device=device, dtype=dtype) + + for q_idx in range(attention_mask.shape[2]): + attention_mask[0, 0, q_idx, :] = block_mask.mask_mod( + torch.zeros(1, device=device, dtype=torch.long), + torch.zeros(1, device=device, dtype=torch.long), + torch.ones(1, device=device, dtype=torch.long) * q_idx, + torch.arange(attention_mask.shape[3], device=device, dtype=torch.long), + ) + return attention_mask + + +def flex_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask, + scaling: float | None = None, + softcap: float | None = None, + head_mask: torch.Tensor | None = None, + s_aux: torch.Tensor | None = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor | None]: + block_mask = attention_mask + enable_gqa = False + + num_local_query_heads = query.shape[1] + # When running TP this helps: + if (num_local_query_heads & (num_local_query_heads - 1)) != 0: + key = repeat_kv(key, query.shape[1] // key.shape[1]) + value = repeat_kv(value, query.shape[1] // value.shape[1]) + + return_lse = query.device.type != "cpu" + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + flex_attention_output = flex_attention( + query, + key, + value, + score_mod=None, + block_mask=block_mask, + enable_gqa=enable_gqa, + scale=scaling, + kernel_options=None, + # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. + # For simplification, we thus always return it as no additional computations are introduced.
+ return_lse=return_lse, + ) + # lse is returned in float32 + if return_lse: + attention_output, lse = flex_attention_output # type: ignore[misc] + lse = lse.to(value.dtype) + else: + attention_output = flex_attention_output # type: ignore[assignment] + lse = None + + attention_output = attention_output.transpose(1, 2).contiguous() + return attention_output, lse + + +ALL_ATTENTION_FUNCTIONS = AttentionInterface() # Singleton class used for registry +ALL_ATTENTION_FUNCTIONS.register("simple_flex_attention", flex_attention_forward) diff --git a/src/speculators/train/eagle3/core.py b/src/speculators/train/eagle3/core.py new file mode 100644 index 0000000..aa59eee --- /dev/null +++ b/src/speculators/train/eagle3/core.py @@ -0,0 +1,231 @@ +import torch +from transformers.configuration_utils import PretrainedConfig +from transformers import AutoModelForCausalLM, DynamicCache + +from torch.nn.attention.flex_attention import create_block_mask + +from speculators.train.eagle3.attention import ( + create_combined_mask_mod, + extend_mask_for_draft_tokens, +) +from speculators.train.eagle3.model_definitions import model_classes + +def load_verifier_embeddings(verifier_model_name_or_path: str): + verifier_model = AutoModelForCausalLM.from_pretrained( + verifier_model_name_or_path + ) + return verifier_model.model.embed_tokens.state_dict() + + +class Eagle3VerifierLMHead(torch.nn.Module): + def __init__(self, hidden_size: int, draft_vocab_size: int): + super().__init__() + self.lm_head = torch.nn.Linear(hidden_size, draft_vocab_size, bias=False) + self.lm_head.weight.requires_grad = False + + def load_verifier_lm_head( + self, verifier_model_name_or_path: str, t2d_vocab: torch.Tensor + ): + verifier_model = AutoModelForCausalLM.from_pretrained( + verifier_model_name_or_path + ) + verifier_lm_head_data = verifier_model.lm_head.weight.data.to(t2d_vocab.device) + truncated_data = verifier_lm_head_data[t2d_vocab, :] + if truncated_data.shape[0] != self.lm_head.weight.shape[0]: + raise ValueError( + f"Truncated verifier lm head data shape {truncated_data.shape} does not match draft lm head shape {self.lm_head.weight.shape}" + ) + self.lm_head.weight.data = truncated_data + + @torch.no_grad() + def forward(self, verifier_last_hidden_states: torch.Tensor): + return self.lm_head(verifier_last_hidden_states) + + +class Eagle3DraftModel(torch.nn.Module): + def __init__( + self, + verifier_model_name_or_path: str, + hidden_size: int, # Must be same for verifier and draft + # Vocab mappings + t2d_vocab: torch.Tensor, + d2t_vocab: torch.Tensor, + decoder_layer_config: PretrainedConfig, + # Verifier + verifier_vocab_size: int, + verifier_pad_token_id: int | None, + # Draft config + num_layers: int = 1, + ttt_steps: int = 3, + ): + super().__init__() + self.verifier_model_name_or_path = verifier_model_name_or_path + self.hidden_size = hidden_size + self.num_layers = num_layers + self.decoder_layer_config = decoder_layer_config + self.ttt_steps = ttt_steps + self.register_buffer( + "t2d_vocab", t2d_vocab + ) # shape: [verifier_vocab_size], bool + self.register_buffer( + "d2t_vocab", d2t_vocab + ) # shape: [draft_vocab_size], int offsets + self.draft_vocab_size = t2d_vocab.sum(dtype=torch.long).item() + model_definitions = model_classes[decoder_layer_config.model_type] + + self.fc_layer = torch.nn.Linear(3 * hidden_size, hidden_size) + self.layers = torch.nn.ModuleList( + [ + model_definitions.decoder_layer_class(decoder_layer_config, layer_idx) + for layer_idx in range(num_layers) + ] + ) + self.norm =
model_definitions.norm_class( + hidden_size, eps=decoder_layer_config.rms_norm_eps + ) + self.rotary_emb = model_definitions.rotary_emb_class(decoder_layer_config) + self.embed_tokens = torch.nn.Embedding( + verifier_vocab_size, hidden_size, padding_idx=verifier_pad_token_id + ) + # shape: [verifier_vocab_size, hidden_size] + self.embed_tokens.load_state_dict(load_verifier_embeddings(verifier_model_name_or_path)) + + self.lm_head = torch.nn.Linear(hidden_size, self.draft_vocab_size, bias=False) + # shape: [hidden_size, draft_vocab_size] + + def loss_function( + self, + logits: torch.Tensor, # shape: [batch_size=1, total_seq_len, draft_vocab_size] + targets: torch.Tensor, # shape: [batch_size=1, total_seq_len, draft_vocab_size] + loss_mask: torch.Tensor, # shape: [batch_size=1, total_seq_len] + ttt_step: int, + ): + # We don't have target values for the last ttt_step tokens, so we mask them out on the logit side + # We shift the target values by ttt_step + 1 to the left because that's the position the generated tokens correspond to + # e.g. + # targets_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9] + # logits_indices_ttt_step_0 = [1, 2, 3, 4, 5, 6, 7, 8, 9] + # logits_indices_ttt_step_1 = [2, 3, 4, 5, 6, 7, 8, 9, 10] + # logits_indices_ttt_step_2 = [3, 4, 5, 6, 7, 8, 9, 10, 11] + # The indices for the loss_mask need to be kept in line with the targets indices + + # Note: this function is written such that a batch_size > 1 is supported. This is through careful handling of the 0-th "batch dimension" + # However, currently the 0-th "batch dimension" is always 1 because we are packing the samples together by extending the sequence (1st) dimension. + # There could be a future use case for batch_size > 1, if pad and stack samples together instead of packing them. + logits = logits[:, :-ttt_step] if ttt_step > 0 else logits + targets = targets[:, ttt_step:] + # logits/targets shape: [batch_size=1, total_seq_len - ttt_step, draft_vocab_size] + loss_mask = loss_mask[:, ttt_step:] + # loss_mask shape: [batch_size=1, total_seq_len - ttt_step] + + logits = torch.nn.functional.log_softmax(logits, dim=-1) + targets = torch.nn.functional.log_softmax(targets, dim=-1) + kl_div = torch.nn.functional.kl_div( + logits, targets, reduction="none", log_target=True + ) + masked_kl_div = torch.sum(loss_mask.unsqueeze(-1) * kl_div, dim=(1, 2)) / ( + loss_mask.sum(dim=1) + 1e-5 + ) + # shape: [batch_size=1] + return masked_kl_div.mean() + + def forward( + self, + hidden_states: torch.Tensor, # shape: [1, total_seq_len, 3 * hidden_size] + input_ids: torch.Tensor, # shape: [1, total_seq_len] + lengths: torch.Tensor | None = None, # shape: [batch_size] + loss_mask: torch.Tensor | None = None, # shape: [1, total_seq_len] + position_ids: torch.Tensor | None = None, # shape: [1, total_seq_len] + target_logits: torch.Tensor + | None = None, # shape: [1, total_seq_len, draft_vocab_size] + ttt_steps: int | None = None, + use_off_policy_tokens: bool = False, + **kwargs, + ): + device = hidden_states.device + total_seq_len = hidden_states.shape[1] + return_loss = target_logits is not None + + if ttt_steps is None: + ttt_steps = self.ttt_steps + if lengths is None: + lengths = torch.tensor([total_seq_len], dtype=torch.long, device=device) + + past_key_values = DynamicCache(config=self.decoder_layer_config) + + combined_mask_mod = create_combined_mask_mod(lengths.to(device)) + block_mask = create_block_mask( + combined_mask_mod, + B=None, + H=None, + Q_LEN=total_seq_len, + KV_LEN=total_seq_len, + device=device, + ) + + hidden_states = 
self.fc_layer(hidden_states) + # shape: [1, total_seq_len, hidden_size] + + if return_loss: + loss = torch.tensor(0.0, device=device) + + original_input_ids = input_ids.detach().clone() + + draft_tokens = [] + for ttt_step in range(ttt_steps): + input_embeds = self.embed_tokens(input_ids) + # shape: [1, total_seq_len, hidden_size] + cache_position = torch.arange( + ttt_step * total_seq_len, + (ttt_step + 1) * total_seq_len, + dtype=torch.long, + device=device, + ) + # shape: [total_seq_len] + + hidden_states = torch.cat([input_embeds, hidden_states], dim=-1) + # shape: [1, total_seq_len, 2 * hidden_size] + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=block_mask, # block_mask_to_dense_attention_mask(block_mask, device, torch.bool), + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + logits = self.lm_head(hidden_states) + # shape: [1, total_seq_len, draft_vocab_size] + + if return_loss: + loss += self.loss_function(logits, target_logits, loss_mask, ttt_step) + + input_ids = torch.argmax(logits, dim=-1) + draft_tokens.append(input_ids.detach().clone()) + # shape: [1, total_seq_len] + # Use d2t to map draft tokens to verifier tokens. + # Must be in verifier vocabulary space because we use full verifier vocabulary in embedding + input_ids = input_ids + self.d2t_vocab[input_ids] + + if use_off_policy_tokens: + # Overwrite input_ids with ground truth tokens + # shift input_ids by 1 to the left and pad with 0 + # note: inputs_ids will no longer line up with verifier_last_hidden_state + # the draft logits generated from the padded tokens are ignored sliced out for loss calculation + input_ids = torch.cat( + [original_input_ids[:, 1:], original_input_ids.new_zeros(1, 1)], dim=-1 + ) + # shape: [1, total_seq_len] + + block_mask = extend_mask_for_draft_tokens(block_mask) + position_ids = position_ids + 1 + # shape: [1, total_seq_len] + + return draft_tokens, loss if return_loss else None diff --git a/src/speculators/train/eagle3/model_definitions.py b/src/speculators/train/eagle3/model_definitions.py new file mode 100644 index 0000000..67a9e1f --- /dev/null +++ b/src/speculators/train/eagle3/model_definitions.py @@ -0,0 +1,104 @@ +from typing import NamedTuple, Optional +import torch +from transformers.configuration_utils import PretrainedConfig +import copy +from transformers import Cache, LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) + +from transformers.processing_utils import Unpack +from transformers.utils.generic import TransformersKwargs + + +class LlamaConcatInputDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: PretrainedConfig, layer_idx: int): + super().__init__(config, layer_idx) + + ##### CHANGES START ##### + if layer_idx == 0: + self.self_attn.q_proj = torch.nn.Linear( + 2 * config.hidden_size, # previous: config.hidden_size + config.num_attention_heads * config.head_dim, + bias=config.attention_bias, + ) + self.self_attn.k_proj = torch.nn.Linear( + 2 * config.hidden_size, # previous: config.hidden_size + config.num_key_value_heads * config.head_dim, + bias=config.attention_bias, + ) + self.self_attn.v_proj = torch.nn.Linear( + 2 * config.hidden_size, # previous: config.hidden_size + config.num_key_value_heads * 
config.head_dim, + bias=config.attention_bias, + ) + self.input_layernorm = LlamaRMSNorm( + 2 * config.hidden_size, # previous: config.hidden_size + eps=config.rms_norm_eps, + ) + self.layer_idx = layer_idx + ##### CHANGES END ##### + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + ##### CHANGES START ##### + # previous: residual = hidden_states + if self.layer_idx == 0: + residual = hidden_states[:, :, hidden_states.shape[2] // 2 :] + else: + residual = hidden_states + ##### CHANGES END ##### + + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class LlamaConcatInputRotaryEmbedding(LlamaRotaryEmbedding): + def __init__(self, config: LlamaConfig, device=None): + config = copy.copy(config) + config.hidden_size = config.hidden_size * 2 + super().__init__(config, device) + + +class ModelComponents(NamedTuple): + decoder_layer_class: type + norm_class: type + rotary_emb_class: type + + +model_classes: dict[str, ModelComponents] = { + "llama": ModelComponents( + LlamaConcatInputDecoderLayer, LlamaRMSNorm, LlamaConcatInputRotaryEmbedding + ), +} diff --git a/src/speculators/train/logger.py b/src/speculators/train/logger.py new file mode 100644 index 0000000..be94401 --- /dev/null +++ b/src/speculators/train/logger.py @@ -0,0 +1,549 @@ +"""Logging utilities for the Speculators training module. + +This module provides a logging system for training machine learning models, +supporting multiple logging backends including TensorBoard (tensorboard), Weights & Biases (wandb). 
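+Trackio (trackio) is also supported.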
+ +Example Usage: + ```python + from speculators.train.logger import setup_metric_logger + + # Setup logging with TensorBoard and wandb + setup_metric_logger( + loggers=["tensorboard", "wandb"], + run_name="my_training_run", + output_dir="logs" + ) + + # Log metrics + import logging + logger = logging.getLogger("speculators.metrics") + + # Log a simple metric + logger.info({"loss": 0.5, "accuracy": 0.95}, extra={"step": 100}) + + # Log nested metrics + logger.info({ + "training": { + "loss": 0.5, + "accuracy": 0.95 + }, + "validation": { + "loss": 0.6, + "accuracy": 0.92 + } + }, extra={"step": 100}) + + # Log hyperparameters + logger.info({ + "learning_rate": 0.001, + "batch_size": 32, + "model": { + "hidden_size": 512, + "num_layers": 6 + } + }, extra={"hparams": True}) + ``` +""" + +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from collections.abc import Mapping +from datetime import datetime, timezone +from logging.config import dictConfig +import importlib +from pathlib import Path +from typing import Any, Union +import logging +import os +import warnings + + +# Third Party +from rich.logging import RichHandler +import torch + +### Helper functions + +LogDict = Mapping[str, Union[str, int, float, "LogDict"]] + + +def _substitute_placeholders( + run_name: str | None, default_template: str = "{time}" +) -> str: + """Replace placeholders in the run name with actual values. + + This function supports dynamic run name generation by replacing placeholders + with actual values from the environment or current time. This is particularly + useful for distributed training scenarios where you want unique run names + for each process. + + Supported placeholders: + - {time}: Current local timestamp in ISO format + - {utc_time}: Current UTC timestamp in ISO format + - {rank}: Process rank from RANK environment variable + - {local_rank}: Local process rank from LOCAL_RANK environment variable + + Args: + run_name: String containing placeholders to be replaced. If None, uses default_template + default_template: Default template to use if run_name is None + + Returns: + String with all placeholders replaced by their values + + Example: + ```python + # With default template + name = _substitute_placeholders(None) + # Result: "2024-03-14T10:30:00_rank0" + + # With custom template + name = _substitute_placeholders("experiment_{time}_rank{rank}") + # Result: "experiment_2024-03-14T10:30:00_rank0" + ``` + """ + if run_name is None: + run_name = default_template + + substitutions = { + "{time}": datetime.now().isoformat(timespec="seconds"), + "{utc_time}": datetime.now(timezone.utc).isoformat(timespec="seconds"), + "{rank}": os.environ.get("RANK", 0), + "{local_rank}": os.environ.get("LOCAL_RANK", 0), + } + for placeholder_pat, value in substitutions.items(): + run_name = run_name.replace(placeholder_pat, str(value)) + + return run_name + + +def _flatten_dict(log_dict: LogDict, sep: str = "/", prefix: str = "") -> dict: + """Flatten a nested dictionary into a single-level dictionary. + + This function recursively traverses a nested dictionary and creates a new + dictionary with keys that represent the path to each value in the original + dictionary. 
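+ + For example, `_flatten_dict({"train": {"loss": 0.5}})` returns `{"train/loss": 0.5}` with the default separator.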
+ + Args: + d: The dictionary to flatten + sep: Separator to use between nested keys + prefix: Prefix to add to all keys + + Returns: + A flattened dictionary with keys joined by the separator + """ + flattened = {} + + for k, v in log_dict.items(): + if isinstance(v, Mapping): + flattened |= _flatten_dict(v, sep=sep, prefix=f"{prefix}{k}{sep}") + else: + flattened[prefix + k] = v + + return flattened + + +### Filters +class IsMappingFilter(logging.Filter): + """Filter that only allows log records with dictionary messages. + + This filter ensures that only log records containing dictionary messages + are processed by the handler. This is useful for metric logging where + we want to ensure all logged messages are structured data. + """ + + def filter(self, record): + """Check if the log record's message is a dictionary. + + Args: + record: The log record to check + + Returns: + bool: True if the message is a dictionary, False otherwise + """ + return isinstance(record.msg, Mapping) + + +class IsRank0Filter(logging.Filter): + """Filter that only allows log records from rank 0 in distributed training. + + This filter is useful in distributed training scenarios where you want to + ensure that only the main process (rank 0) logs metrics to avoid duplicate + logging. The rank can be determined from various sources in order of precedence: + 1. Explicitly provided rank value + 2. Record's rank attribute + 3. Record's message dictionary + 4. Environment variables + 5. PyTorch distributed rank + + Args: + rank_val: Optional explicit rank value to use + local_rank: If True, use local_rank instead of global rank + """ + + def __init__(self, rank_val: int | None = None, local_rank: bool = False): + self.rank_val = rank_val + if local_rank: + self.rank_attr = "local_rank" + else: + self.rank_attr = "rank" + + def _get_rank(self, record): + rank = ( + self.rank_val + or getattr(record, self.rank_attr, None) + or (isinstance(record.msg, Mapping) and record.msg.get(self.rank_attr)) + or os.environ.get(self.rank_attr.upper(), None) + or ( + self.rank_attr == "rank" + and torch.distributed.is_initialized() + and torch.distributed.get_rank() + ) + or 0 + ) + + return int(rank) + + def filter(self, record): + return self._get_rank(record) == 0 + + +class FormatDictFilter(logging.Filter): + """Reformats dictionary messages for prettier printing. + + This filter processes dictionary messages to create a more readable string + representation. It handles different types of values appropriately: + - Floats are formatted with 3 decimal places or scientific notation + - Integers are formatted as decimal numbers + - Other types are converted to their string representation + + Note: This is not a true filter, but a processing step as described in the + Python logging cookbook: https://docs.python.org/3/howto/logging-cookbook.html#using-filters-to-impart-contextual-information + """ + + @staticmethod + def _format_value(v): + if isinstance(v, float): + if abs(v) < 0.001 or abs(v) > 999: + return f"{v:.2e}" + return f"{v:.3f}" + elif isinstance(v, int): + return f"{v:d}" + else: + return repr(v) + + def filter(self, record): + if not isinstance(record.msg, Mapping): + return True + flat_dict = _flatten_dict(record.msg) + + record.msg = ", ".join( + f"{k}={self._format_value(v)}" for k, v in flat_dict.items() + ) + + return True + + +### Handlers +class TensorBoardHandler(logging.Handler): + """Logger that writes metrics to TensorBoard. 
+ + This handler expects a (nested) dictionary of metrics or text to be logged with string keys. + A step can be specified by passing `extra={"step": }` to the logging method. + To log hyperparameters, pass a (nested) mapping of hyperparameters to the logging method + and set `extra={"hparams": True}`. + """ + + def __init__( + self, + level: int = logging.INFO, + run_name: str | None = None, + log_dir: str | os.PathLike = "logs", + **tboard_init_kwargs: Any, + ): + """Initialize the TensorBoard logger and check for required dependencies. + + Args: + level: The logging level for this handler + run_name: Name of the run, can contain placeholders + log_dir: Directory where TensorBoard logs should be stored + """ + super().__init__(level) + + self.tboard_init_kwargs = tboard_init_kwargs.copy() + self.tboard_init_kwargs.setdefault( + "log_dir", Path(log_dir) / _substitute_placeholders(run_name) + ) + + self._tboard_writer = None + + def _setup(self): + """Create the TensorBoard log directory and initialize the writer. + + Raises: + RuntimeError: If tensorboard package is not installed + """ + + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError as e: + msg = ( + "Could not initialize TensorBoardHandler because package tensorboard could not be imported.\n" + "Please ensure it is installed by running 'pip install tensorboard' or configure the logger to use a different backend." + ) + raise RuntimeError(msg) from e + + os.makedirs(self.tboard_init_kwargs["log_dir"], exist_ok=True) + self._tboard_writer = SummaryWriter(**self.tboard_init_kwargs) + + def emit(self, record: logging.LogRecord): + """Emit a log record to TensorBoard. + + This method handles both scalar metrics and text logs, automatically + detecting the type of data being logged. + + Args: + record: The log record to emit + """ + if self._tboard_writer is None: + self._setup() + + if not isinstance(record.msg, Mapping): + warnings.warn( + f"TensorBoardHandler expected a mapping, got {type(record.msg)}. Skipping log. Please ensure the handler is configured correctly to filter out non-mapping objects." + ) + return + + flat_dict = _flatten_dict(record.msg) + step = getattr(record, "step", None) + if getattr(record, "hparams", None): + self._tboard_writer.add_hparams( + flat_dict, {}, run_name=".", global_step=step + ) + return + + for k, v in flat_dict.items(): + try: + # Check that `v` can be converted to float + float(v) + except ValueError: + # Occurs for strings that cannot be converted to floats (e.g. "3.2.3") and aren't "inf" or "nan" + self._tboard_writer.add_text(k, v, global_step=step) + except TypeError: + warnings.warn( + f"TensorBoardHandler expected a scalar or text, got {type(v)}. Skipping log. Please ensure metric logger is only called with mappings containing scalar values or text." + ) + else: + self._tboard_writer.add_scalar(k, v, global_step=step) + + def flush(self): + """Flush the TensorBoard writer.""" + if self._tboard_writer is not None: + self._tboard_writer.flush() + + def close(self): + """Close the TensorBoard writer and cleanup resources.""" + if self._tboard_writer is not None: + self._tboard_writer.close() + self._tboard_writer = None + super().close() + + +class WandbHandler(logging.Handler): + """Logger that sends metrics to Weights & Biases (wandb). + + This handler expects a (nested) dictionary of metrics or text to be logged with string keys. + A step can be specified by passing `extra={"step": }` to the logging method. 
+ To log hyperparameters, pass a (nested) mapping of hyperparameters to the logging method + and set `extra={"hparams": True}`. + """ + + def __init__( + self, + level: int = logging.INFO, + run_name: str | None = None, + log_dir: str | os.PathLike = "logs", + **init_kwargs: Any, + ): + """Initialize the wandb logger and check for required dependencies. + + Args: + level: The logging level for this handler + run_name: Name of the run, can contain placeholders + log_dir: Directory where wandb logs should be stored + """ + super().__init__(level) + + self.init_kwargs = init_kwargs.copy() + self.init_kwargs.setdefault("dir", Path(log_dir)) + self.init_kwargs.setdefault("name", _substitute_placeholders(run_name)) + self.init_kwargs.setdefault("config", {}) + + self._package_name = "wandb" + self._run = None + + def _setup(self): + try: + wandb = importlib.import_module(self._package_name) + except ImportError as e: + msg = ( + f"Could not initialize {self.__class__.__name__} because package {self._package_name} could not be imported.\n" + f"Please ensure it is installed by running 'pip install {self._package_name}' or configure the logger to use a different backend." + ) + raise RuntimeError(msg) from e + + self._run = wandb.init(**self.init_kwargs) + + def emit(self, record: logging.LogRecord): + if self._run is None: + self._setup() + + if not isinstance(record.msg, Mapping): + warnings.warn( + f"{self.__class__.__name__} expected a mapping, got {type(record.msg)}. Skipping log. Please ensure the handler is configured correctly to filter out non-mapping objects." + ) + return + + flat_dict = _flatten_dict(record.msg) + step = getattr(record, "step", None) + if getattr(record, "hparams", None): + for k, v in flat_dict.items(): + self._run.config[k] = v + return + + self._run.log(flat_dict, step=step) + + +class TrackioHandler(WandbHandler): + """Logger that sends metrics to Trackio. + + This handler expects a (nested) dictionary of metrics or text to be logged with string keys. + A step can be specified by passing `extra={"step": }` to the logging method. + To log hyperparameters, pass a (nested) mapping of hyperparameters to the logging method + and set `extra={"hparams": True}`. + """ + + def __init__( + self, + level: int = logging.INFO, + run_name: str | None = None, + log_dir: str | os.PathLike = "logs", + **init_kwargs: Any, + ): + """Initialize the trackio logger and check for required dependencies. + + Args: + level: The logging level for this handler + run_name: Name of the run, can contain placeholders + """ + super().__init__(level) + + self.init_kwargs = init_kwargs.copy() + self.init_kwargs.setdefault("name", _substitute_placeholders(run_name)) + self.init_kwargs.setdefault("config", {}) + self.init_kwargs.setdefault("project", "speculators") + + # Trackio doesn't support the dir keyword argument so we ignore log_dir + + self._package_name = "trackio" + self._run = None + + +### Main functions + + +def setup_root_logger(level="INFO"): + """Configure the root logger with rich formatting. + + This function sets up the root logger with a RichHandler for + console output and adds the FormatDictFilter for better dictionary message + formatting. 
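+ It also attaches an IsRank0Filter(local_rank=True) so that only the local rank 0 process writes console output.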
+ """ + handler = RichHandler() + handler.addFilter(FormatDictFilter()) + handler.addFilter(IsRank0Filter(local_rank=True)) + logging.basicConfig( + level=level, format="%(message)s", datefmt="[%X]", handlers=[handler] + ) + + +def setup_metric_logger(loggers, run_name, output_dir): + """Configure the metric logging system with specified backends. + + This function sets up a comprehensive logging configuration that supports + multiple logging backends simultaneously. It configures filters, handlers, + and loggers for structured metric logging. + + Args: + loggers: A string or list of strings specifying which logging backends to use. + Supported values: "tensorboard", "wandb", "trackio" + run_name: Name for the current training run. Can include placeholders like + {time}, {rank}, {utc_time}, {local_rank}. + output_dir: Directory where log files will be stored + + Example: + ```python + # Setup logging with multiple backends + setup_metric_logger( + loggers=["tensorboard", "wandb", "trackio"], + run_name="experiment_{time}", + output_dir="logs" + ) + + # Setup logging with a single backend + setup_metric_logger( + loggers="tensorboard", + run_name="my_run", + output_dir="logs" + ) + ``` + """ + if isinstance(loggers, str): + loggers = loggers.split(",") + loggers = [logger.strip() for logger in loggers] + + logging_config = { + "version": 1, + "disable_existing_loggers": False, + "filters": { + "is_mapping": { + "()": IsMappingFilter, + }, + "is_rank0": { + "()": IsRank0Filter, + }, + }, + "handlers": { + "tensorboard": { + "()": TensorBoardHandler, + "log_dir": output_dir, + "run_name": run_name, + "filters": ["is_mapping", "is_rank0"], + }, + "wandb": { + "()": WandbHandler, + "log_dir": output_dir, + "run_name": run_name, + "filters": ["is_mapping", "is_rank0"], + }, + "trackio": { + "()": TrackioHandler, + "log_dir": output_dir, + "run_name": run_name, + "filters": ["is_mapping", "is_rank0"], + }, + }, + "loggers": { + "speculators.metrics": { + "handlers": loggers, + "filters": ["is_mapping"], + "level": "INFO", + "propagate": True, + }, + "speculators": { + "level": "INFO", + "propagate": True, + }, + }, + } + dictConfig(logging_config) diff --git a/src/speculators/train/trainer.py b/src/speculators/train/trainer.py new file mode 100644 index 0000000..12220f8 --- /dev/null +++ b/src/speculators/train/trainer.py @@ -0,0 +1,219 @@ +import torch +from torch.distributed.fsdp import FSDPModule, fully_shard, MixedPrecisionPolicy +from torch.utils.data import DataLoader +from tqdm.rich import tqdm # todo: requries tqdm and rich + + +import torch.distributed as dist +import logging + +from speculators.train.checkpointer import ( + SingleGPUCheckpointer, + DistributedCheckpointer, +) + +root_logger = logging.getLogger("speculators") +metric_logger = logging.getLogger("speculators.metrics") + + +@torch.no_grad() +def compute_draft_accuracy( + target_logits: torch.Tensor, + draft_tokens: list[torch.Tensor], + loss_mask: torch.Tensor | None = None, +): + accuracies = [] + target_tokens = torch.argmax(target_logits, dim=-1) + # shape: [batch_size, total_seq_len] + + for step, drafts in enumerate(draft_tokens): + correct = target_tokens[:, step:] == (drafts[:, : -step] if step > 0 else drafts) + if loss_mask is not None: + correct = correct[:, loss_mask[:, step:]] + accuracies.append(correct.float().mean()) + + return torch.tensor(accuracies, device=target_logits.device) + + +def apply_fully_sharded(model: torch.nn.Module): + fsdp_kwargs = { + "mp_policy": MixedPrecisionPolicy( + 
param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + ) + } + + for layer in model.layers: # todo: this is hardcoded to the Eagle3DraftModel definition, should be made more general + # we apply fully_shard to each DecoderLayer + layer.to_empty(device="meta") + fully_shard(layer, **fsdp_kwargs) + + fully_shard(model, **fsdp_kwargs) + + return model + + +class Trainer: + def __init__( + self, + model: torch.nn.Module, + verifier_lm_head: torch.nn.Module, + config: dict, + train_loader: DataLoader, + val_loader: DataLoader | None = None, + is_distributed: bool = False, + local_rank: int = 0, + world_size: int = 1, + ): + self.model = model + self.verifier_lm_head = verifier_lm_head + self.config = config + self.train_loader = train_loader + self.val_loader = val_loader + self.is_distributed = is_distributed + self.local_rank = local_rank + self.world_size = world_size + checkpointer_class = ( + DistributedCheckpointer if is_distributed else SingleGPUCheckpointer + ) + self.checkpointer = checkpointer_class( + config["save_path"], + try_load_last_checkpoint=config.get("resume_from_checkpoint", False), + ) + + self.setup_trainer() + self.setup_model() + self.setup_optimizer() + + def setup_trainer(self): + self.current_epoch = self.checkpointer.previous_epoch + 1 + self.global_step = 0 + + def setup_model(self): + if self.is_distributed: + apply_fully_sharded(self.model) + + if self.checkpointer.previous_epoch != -1: + self.checkpointer.load_model_state_dict(self.model) + else: + for m in self.model.layers.children(): # todo: generalize + if not isinstance(m, FSDPModule): + continue + m.to_empty(device="cuda") # todo: generalize + for sub_module in m.modules(): + if hasattr(sub_module, "reset_parameters"): + sub_module.reset_parameters() + # todo: We need to make sure we're loading lm_head and embed_tokens after this reset + else: + self.model.to(self.local_rank) + self.verifier_lm_head = self.verifier_lm_head.to(self.local_rank) + + def setup_optimizer(self): + self.opt = torch.optim.Adam(self.model.parameters(), lr=self.config["lr"]) + if self.checkpointer.previous_epoch != -1: + self.checkpointer.load_optimizer_state_dict(self.model, self.opt) + + def train_epoch(self, epoch: int): + self.model.train() + self.train_loader.batch_sampler.set_epoch( + epoch + ) # todo: check if this is safe to call + + if self.local_rank == 0: + train_loader = tqdm(self.train_loader, desc=f"Epoch {epoch}") + else: + train_loader = self.train_loader + root_logger.info(f"Training Epoch {epoch} started") + + for batch in train_loader: + batch = { + k: v.to(self.local_rank) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } + target_logits = self.verifier_lm_head(batch["verifier_last_hidden_states"]) + del batch["verifier_last_hidden_states"] + + draft_tokens, loss = self.model( + **batch, target_logits=target_logits, use_off_policy_tokens=True + ) # set this in a better way + + self.opt.zero_grad() + loss.backward() + self.opt.step() + + draft_accuracies = compute_draft_accuracy( + target_logits, draft_tokens, batch["loss_mask"] + ) + + loss = loss.detach().clone() + if self.is_distributed: + # Note: this is not needed for training, just for logging + dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG) + dist.reduce(draft_accuracies, dst=0, op=dist.ReduceOp.AVG) + + acc_values = { + f"acc_{i}": acc.item() for i, acc in enumerate(draft_accuracies) + } + metric_logger.info( + {"train": {"loss": loss.item(), **acc_values}, "epoch": epoch}, + extra={"step": self.global_step}, + ) + self.global_step 
+
+        root_logger.info(f"Training Epoch {epoch} completed")
+
+    def val_epoch(self, epoch: int):
+        if self.val_loader is None:
+            root_logger.warning("No val loader, skipping validation")
+            return
+        self.model.eval()
+        self.val_loader.batch_sampler.set_epoch(epoch)
+
+        if self.local_rank == 0:
+            val_loader = tqdm(self.val_loader, desc=f"Epoch {epoch}")
+        else:
+            val_loader = self.val_loader
+
+        root_logger.info(f"Validation Epoch {epoch} started")
+
+        for i, batch in enumerate(val_loader):
+            batch = {
+                k: v.to(self.local_rank) if isinstance(v, torch.Tensor) else v
+                for k, v in batch.items()
+            }
+            target_logits = self.verifier_lm_head(batch["verifier_last_hidden_states"])
+            del batch["verifier_last_hidden_states"]
+
+            draft_tokens, val_loss = self.model(
+                **batch, target_logits=target_logits, use_off_policy_tokens=True
+            )  # set this in a better way
+
+            draft_accuracies = compute_draft_accuracy(
+                target_logits, draft_tokens, batch["loss_mask"]
+            )
+
+            if self.is_distributed:
+                dist.reduce(val_loss, dst=0, op=dist.ReduceOp.AVG)
+                dist.reduce(draft_accuracies, dst=0, op=dist.ReduceOp.AVG)
+
+            # todo: Accumulate these values across the epoch and then log at the end of the epoch
+            acc_values = {
+                f"acc_{i}": acc.item() for i, acc in enumerate(draft_accuracies)
+            }
+            metric_logger.info(
+                {"val": {"loss": val_loss.item(), **acc_values}}, extra={"step": i}
+            )
+
+        root_logger.info(f"Validation Epoch {epoch} completed")
+
+    def save_checkpoint(self, epoch: int):
+        self.checkpointer.save_checkpoint(self.model, self.opt, epoch)
+        root_logger.info(f"Checkpoint saved to {self.checkpointer.path / str(epoch)}")
+
+    def run_training(self):
+        for epoch in range(self.current_epoch, self.config["num_epochs"]):
+            self.train_epoch(epoch)
+            if self.is_distributed:
+                dist.barrier()
+            self.val_epoch(epoch)
+            self.save_checkpoint(epoch)
diff --git a/src/speculators/train/utils.py b/src/speculators/train/utils.py
new file mode 100644
index 0000000..ec3dbb9
--- /dev/null
+++ b/src/speculators/train/utils.py
@@ -0,0 +1,40 @@
+import os
+import torch
+import torch.distributed as dist
+
+local_rank = int(os.environ.get("LOCAL_RANK", 0))
+
+
+def maybe_setup_distributed():
+    # Based on https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html#initialize-ddp-with-torch-distributed-run-torchrun
+    if "LOCAL_RANK" not in os.environ:
+        # No distributed training
+        return 0, 1, 0, False
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+    torch.accelerator.set_device_index(local_rank)
+    acc = torch.accelerator.current_accelerator()
+    backend = torch.distributed.get_default_backend_for_device(acc)
+    dist.init_process_group(backend)
+
+    rank = dist.get_rank()
+
+    print(
+        f"Started DDP with local_rank={local_rank}, world_size={world_size}, rank={rank}"
+    )
+    return local_rank, world_size, rank, True
+
+
+def maybe_destroy_distributed():
+    if "LOCAL_RANK" not in os.environ:
+        # No distributed training
+        return
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    rank = dist.get_rank()
+
+    dist.destroy_process_group()
+    print(
+        f"Destroyed DDP with local_rank={local_rank}, world_size={world_size}, rank={rank}"
+    )
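Quick reviewer sketch (not part of the patch): because `maybe_setup_distributed` falls back to `(0, 1, 0, False)` when `LOCAL_RANK` is unset and `maybe_destroy_distributed` returns early in that case, the same entry point can be exercised without `torchrun`:

```python
# Sketch only: exercises the single-process fallback of the helpers above.
import os

from speculators.train.utils import maybe_destroy_distributed, maybe_setup_distributed

os.environ.pop("LOCAL_RANK", None)  # make sure the non-distributed branch is taken

local_rank, world_size, rank, is_distributed = maybe_setup_distributed()
assert (local_rank, world_size, rank, is_distributed) == (0, 1, 0, False)

maybe_destroy_distributed()  # no-op when not distributed
```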
diff --git a/tests/unit/train/test_eagle3_attention.py b/tests/unit/train/test_eagle3_attention.py
new file mode 100644
index 0000000..61fbc16
--- /dev/null
+++ b/tests/unit/train/test_eagle3_attention.py
@@ -0,0 +1,130 @@
+import pytest
+from speculators.train.eagle3.attention import (
+    create_combined_mask_mod,
+    extend_mask_for_draft_tokens,
+    flex_attention_forward,
+)
+from torch.nn.attention.flex_attention import BlockMask
+import torch
+
+
+def test_create_combined_mask_mod():
+    lengths = torch.tensor([1, 2, 3])
+    mask_mod = create_combined_mask_mod(lengths)
+
+    # Creates a causal document mask mod that supports extended diagonals
+    # lengths -> document ids [0, 1, 1, 2, 2, 2]
+    # Expected mask mod values for q_idx (row), kv_idx (column):
+    expected_mask_mod = [
+        [1, 0, 0, 0, 0, 0],
+        [0, 1, 0, 0, 0, 0],
+        [0, 1, 1, 0, 0, 0],
+        [0, 0, 0, 1, 0, 0],
+        [0, 0, 0, 1, 1, 0],
+        [0, 0, 0, 1, 1, 1],
+    ]
+    t0 = torch.tensor(0)
+
+    for q_idx in range(len(expected_mask_mod)):
+        for kv_idx in range(len(expected_mask_mod[q_idx])):
+            assert mask_mod(t0, t0, q_idx, kv_idx) == expected_mask_mod[q_idx][kv_idx]
+
+
+@pytest.mark.parametrize(
+    "lengths", [torch.tensor([1, 2, 3]), torch.tensor([2, 2, 2]), torch.tensor([5])]
+)
+def test_diagonal_draft_tokens_mask_mod(lengths):
+    # Causal    Diagonal
+    # ⌄ ⌄ ⌄ | ⌄ ⌄ ⌄ ⌄ ⌄ ⌄
+    # 1 0 0 | 1 0 0 1 0 0
+    # 1 1 0 | 0 1 0 0 1 0
+    # 1 1 1 | 0 0 1 0 0 1
+    # If kv_idx >= N (N = original seq len = num query indices), only the diagonal tokens are in the mask
+    # Diagonal tokens are those where kv_idx % N == q_idx
+
+    mask_mod = create_combined_mask_mod(lengths)
+
+    N = lengths.sum().item()
+
+    t0 = torch.tensor(0)
+    for q_idx in range(N):
+        for kv_idx in range(N, 3 * N):
+            assert mask_mod(t0, t0, q_idx, kv_idx) == (kv_idx % N == q_idx)
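Reviewer sketch (not part of the patch): the same `mask_mod` can be materialized into a dense boolean mask with torch's `create_mask` helper, which makes the per-document causal block and the draft-token diagonals visible at a glance; `KV_LEN = 3 * N` mirrors the range checked in the diagonal test above.

```python
# Sketch only: visualize the tested mask_mod as a dense mask.
import torch
from torch.nn.attention.flex_attention import create_mask

from speculators.train.eagle3.attention import create_combined_mask_mod

lengths = torch.tensor([1, 2, 3])
N = int(lengths.sum())  # 6 packed tokens

mask_mod = create_combined_mask_mod(lengths)
# A single batch/head is enough here; the result has shape (1, 1, Q_LEN, KV_LEN).
dense = create_mask(mask_mod, 1, 1, N, 3 * N, device="cpu")

# Columns [0, N): per-document causal mask; columns [N, 3N): diagonal-only draft positions.
print(dense[0, 0].int())
```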
+
+
+@pytest.mark.parametrize(
+    "kv_num_blocks, kv_indices, expected_kv_indices",
+    [
+        # Test 1: Dense matrix shown in comments in test code
+        (
+            torch.tensor([2, 2, 1]),
+            torch.tensor([[0, 2, -1], [0, 1, -1], [1, -1, -1]]),
+            torch.tensor([[0, 2, 3], [0, 1, 4], [1, 5, -1]]),
+        ),
+        # Test 2: Dense matrix below
+        # 0 1 1 0
+        # 1 0 1 1
+        # 1 0 0 1
+        # 1 1 1 1
+        (
+            torch.tensor([2, 3, 2, 4]),
+            torch.tensor([[1, 2, -1, -1], [0, 2, 3, -1], [0, 3, -1, -1], [0, 1, 2, 3]]),
+            torch.tensor(
+                [
+                    [1, 2, 4, -1, -1],
+                    [0, 2, 3, 5, -1],
+                    [0, 3, 6, -1, -1],
+                    [0, 1, 2, 3, 7],
+                ]
+            ),
+        ),
+    ],
+)
+def test_extend_mask_for_draft_tokens(kv_num_blocks, kv_indices, expected_kv_indices):
+    # The block mask is stored in Block Compressed Sparse Row (BSR) format
+    # This means storing:
+    # - kv_num_blocks (shape: [batch, head, q_blocks]): the number of kv blocks present for each batch, head, and query block
+    # - kv_indices (shape: [batch, head, q_blocks, kv_blocks]): the kv block (column) indices of those blocks for each batch, head, and query block
+    # Only the first kv_num_blocks entries of each row of kv_indices are defined
+    # e.g. To store (ignoring batch and head dimensions):
+    # 1 0 1
+    # 1 1 0
+    # 0 1 0
+    # There are 2 blocks for the first query row (0, 2), 2 blocks for the second query row (0, 1), and 1 block for the third query row (1)
+    # Therefore:
+    # kv_num_blocks = [2, 2, 1]
+    # kv_indices = [[[0, 2, U], [0, 1, U], [1, U, U]]] where U indicates the value is undefined
+    # Note: our masks currently ignore the batch and head indices in the mask function, so we treat those dimensions as size 1 when storing the BlockMask
+
+    # During ttt, we extend the mask to accommodate the new draft tokens. The tokens included will be those on the diagonal (see the diagonal test above),
+    # and therefore we need to include blocks on the newly added diagonal.
+
+    # Therefore, we expect `kv_num_blocks` to increase by 1 for each query row because only the diagonal block will be added to each row.
+    # We also expect `kv_indices` to include the new diagonal blocks for each query row.
+
+    kv_num_blocks = kv_num_blocks.reshape(1, 1, *kv_num_blocks.shape)
+    kv_indices = kv_indices.reshape(1, 1, *kv_indices.shape)
+    expected_kv_indices = expected_kv_indices.reshape(1, 1, *expected_kv_indices.shape)
+
+    def dummy_mask_mod(b, h, q_idx, kv_idx):
+        return True
+
+    block_mask = BlockMask.from_kv_blocks(
+        kv_num_blocks=kv_num_blocks.clone(),
+        kv_indices=kv_indices.clone(),
+        mask_mod=dummy_mask_mod,
+    )
+
+    extended_mask = extend_mask_for_draft_tokens(block_mask)
+
+    for q_idx in range(kv_num_blocks.shape[2]):
+        num_defined_blocks_in_row = extended_mask.kv_num_blocks[0, 0, q_idx].item()
+        # Only the first num_defined_blocks_in_row entries of each row of kv_indices are defined; the rest can have any value
+        # Check that the defined blocks match the expected values
+        assert torch.equal(
+            extended_mask.kv_indices[0, 0, q_idx, :num_defined_blocks_in_row],
+            expected_kv_indices[0, 0, q_idx, :num_defined_blocks_in_row],
+        )
+
+    assert extended_mask.mask_mod == block_mask.mask_mod
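One more reviewer aside: the hand-written BSR fixtures above can be cross-checked against their dense form with a tiny helper. This is an illustrative sketch only; `dense_to_kv_blocks` is hypothetical and not part of the patch.

```python
# Sketch only: dense block mask -> (kv_num_blocks, kv_indices), matching the BSR
# layout described in the comments of test_extend_mask_for_draft_tokens.
import torch


def dense_to_kv_blocks(dense: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """dense: (q_blocks, kv_blocks) 0/1 matrix."""
    kv_num_blocks = dense.sum(dim=-1).to(torch.int32)
    # A stable descending sort puts the present blocks' column indices first;
    # entries past kv_num_blocks[row] are undefined filler.
    kv_indices = torch.argsort(dense, dim=-1, descending=True, stable=True).to(torch.int32)
    return kv_num_blocks, kv_indices


dense = torch.tensor([[1, 0, 1],
                      [1, 1, 0],
                      [0, 1, 0]])
num, idx = dense_to_kv_blocks(dense)
print(num)  # tensor([2, 2, 1], dtype=torch.int32)
print(idx)  # defined prefixes per row: [0, 2], [0, 1], [1] -- matching the first parametrized case
```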