Commit 6dbf4e5

Initial pieces for eagle converter and unit tests

Signed-off-by: Mark Kurtz <[email protected]>
1 parent 4bab336 commit 6dbf4e5

File tree

3 files changed: +1083 -1 lines changed


src/speculators/convert/converters/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -10,5 +10,6 @@
 from __future__ import annotations

 from .base import SpeculatorConverter
+from .eagle import EagleSpeculatorConverter

-__all__ = ["SpeculatorConverter"]
+__all__ = ["EagleSpeculatorConverter", "SpeculatorConverter"]
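
With the added export, EagleSpeculatorConverter can be imported directly from the
converters package. A minimal sketch of that import path (the assertion is
illustrative and not part of the commit)::

    from speculators.convert.converters import (
        EagleSpeculatorConverter,
        SpeculatorConverter,
    )

    # The new class registers itself with SpeculatorConverter under the "eagle",
    # "eagle2", and "hass" names (see eagle.py below) and subclasses the generic
    # SpeculatorConverter base, so this check holds by construction.
    assert issubclass(EagleSpeculatorConverter, SpeculatorConverter)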
src/speculators/convert/converters/eagle.py

Lines changed: 362 additions & 0 deletions
@@ -0,0 +1,362 @@
"""
Eagle/HASS checkpoint converter for Speculators model format.

This module provides the EagleSpeculatorConverter class for transforming Eagle-style
speculative decoding checkpoints (including HASS variants) from research repositories
into the standardized Speculators format. The converter handles automatic feature
detection, weight remapping, configuration translation, and optional validation.

::
    from speculators.convert.converters import EagleSpeculatorConverter

    # Convert with automatic feature detection
    converter = EagleSpeculatorConverter(
        model="path/to/eagle_checkpoint",
        config="path/to/config.json",
        verifier="meta-llama/Meta-Llama-3.1-8B-Instruct"
    )
    converted_model = converter(output_path="./output", validate_device="cuda")
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import Annotated, Literal

import torch
from loguru import logger
from torch import Tensor, nn
from transformers import LlamaConfig, PretrainedConfig, PreTrainedModel

from speculators.config import SpeculatorsConfig, VerifierConfig
from speculators.convert.converters.base import SpeculatorConverter
from speculators.models.eagle import EagleSpeculator, EagleSpeculatorConfig
from speculators.proposals.greedy import GreedyTokenProposalConfig
from speculators.utils import (
    load_model_checkpoint_config_dict,
    load_model_checkpoint_state_dict,
)

__all__ = ["EagleSpeculatorConverter"]


@SpeculatorConverter.register(["eagle", "eagle2", "hass"])
class EagleSpeculatorConverter(
    SpeculatorConverter[EagleSpeculatorConfig, EagleSpeculator]
):
    """
    Converter for Eagle/HASS research checkpoint format to Speculators format.

    This converter transforms Eagle-style speculative decoding checkpoints into the
    standardized Speculators format, handling weight remapping, configuration
    translation, and feature detection. It supports both the original Eagle
    architecture and its variants including HASS, automatically detecting model
    features such as fusion bias and layernorms based on checkpoint structure.

    Example:
    ::
        from speculators.convert.converters import EagleSpeculatorConverter

        converter = EagleSpeculatorConverter(
            model="path/to/eagle_checkpoint",
            config="path/to/config.json",
            verifier="meta-llama/Meta-Llama-3.1-8B-Instruct"
        )
        converted_model = converter(output_path="./output", validate_device="cuda")

    :cvar weight_mappings: Parameter name mappings from Eagle to Speculators format
    :cvar layernorm_mappings: LayerNorm parameter mappings from Eagle to Speculators
    """

    weight_mappings: Annotated[
        dict[str, str],
        "Parameter name mappings from Eagle checkpoint format to Speculators format",
    ] = {"fc.": "fusion_fc.", "layers.0.": "transformer."}
    layernorm_mappings: Annotated[
        dict[str, str],
        "LayerNorm param mappings from Eagle checkpoint format to Speculators format",
    ] = {
        "embed_layernorm.weight": "embedding_layernorm.weight",
        "hidden_layernorm.weight": "transformer.input_layernorm.weight",
        "lm_head_layernorm.weight": "pre_lm_head_layernorm.weight",
    }

    @classmethod
    def is_supported(
        cls,
        model: Path | PreTrainedModel | nn.Module,
        config: Path | PretrainedConfig | dict,  # noqa: ARG003
        verifier: str | os.PathLike | PreTrainedModel | None = None,  # noqa: ARG003
        fusion_bias: bool | None = None,  # noqa: ARG003
        layernorms: bool | None = None,  # noqa: ARG003
        **kwargs,  # noqa: ARG003
    ) -> bool:
        """
        Check if the provided model checkpoint is supported by this converter.

        Validates that the model follows the Eagle architecture pattern by checking
        for the presence of fusion layer weights and single transformer layer
        structure.

        :param model: Model checkpoint path or instance to validate
        :param config: Model configuration (unused for Eagle detection)
        :param verifier: Optional verifier model (unused for Eagle detection)
        :param fusion_bias: Optional fusion bias setting (unused for Eagle detection)
        :param layernorms: Optional layernorms setting (unused for Eagle detection)
        :param kwargs: Additional arguments (unused for Eagle detection)
        :return: True if the model follows Eagle architecture pattern
        """
        state_dict = load_model_checkpoint_state_dict(model)
        has_fc = "fc.weight" in state_dict
        has_layers_0 = any(name.startswith("layers.0.") for name in state_dict)
        has_layers_non_0 = any(
            name.startswith("layers.") and not name.startswith("layers.0.")
            for name in state_dict
        )

        return has_fc and has_layers_0 and not has_layers_non_0

    def __init__(
        self,
        model: Path | PreTrainedModel | nn.Module,
        config: Path | PretrainedConfig | dict,
        verifier: str | os.PathLike | PreTrainedModel | None = None,
        fusion_bias: bool | None = None,
        layernorms: bool | None = None,
    ):
        """
        Initialize the Eagle converter with model, configuration, and feature
        settings.

        :param model: Model checkpoint path or instance to convert
        :param config: Model configuration path or instance
        :param verifier: Optional verifier model path or instance for speculative
            decoding
        :param fusion_bias: Whether to include fusion bias in conversion. If None,
            automatically detected from checkpoint structure
        :param layernorms: Whether to include extra layernorms in conversion. If None,
            automatically detected from checkpoint structure
        """
        super().__init__(
            model=model,
            config=config,
            verifier=verifier,
        )
        self.fusion_bias = fusion_bias
        self.layernorms = layernorms

    def convert_config_state_dict(
        self,
    ) -> tuple[EagleSpeculatorConfig, dict[str, Tensor]]:
        """
        Convert Eagle/HASS checkpoint configuration and state dict to Speculators
        format.

        Processes the original Eagle checkpoint by detecting features, remapping
        weights, and creating a compatible EagleSpeculatorConfig. Handles automatic
        detection of fusion bias and layernorms based on checkpoint structure.

        :return: Tuple of converted configuration and remapped state dictionary
        """
        logger.info(
            f"Converting Eagle/HASS checkpoint at model: {self.model} and "
            f"config: {self.config} to speculators format..."
        )
        orig_state_dict = load_model_checkpoint_state_dict(self.model)
        orig_config = load_model_checkpoint_config_dict(self.config)
        fusion_bias = (
            self.fusion_bias
            if self.fusion_bias is not None
            else "fc.bias" in orig_state_dict
        )
        layernorms = (
            self.layernorms
            if self.layernorms is not None
            else any(name in orig_state_dict for name in self.layernorm_mappings)
        )

        converted_config = self._eagle_speculator_config(
            orig_config, fusion_bias, layernorms
        )
        logger.info(
            f"Converted Eagle/HASS config to speculators format: {converted_config}"
        )

        converted_state_dict, extra = self._eagle_speculator_state_dict(
            orig_state_dict, fusion_bias, layernorms
        )
        logger.info(
            "Converted Eagle/HASS state_dict to speculators format: "
            f"{converted_state_dict.keys()}"
        )
        if extra:
            logger.warning(f"Extra keys in converted state_dict: {extra}")

        return converted_config, converted_state_dict

    def validate(self, model: EagleSpeculator, device: str | torch.device | int):
        """
        Validate the converted model by running a forward pass with test data.

        Ensures the converted EagleSpeculator model is correctly configured and can
        process inputs without errors. Uses conservative defaults for batch size and
        sequence length to minimize resource requirements.

        :param model: The converted EagleSpeculator model to validate
        :param device: Device for validation (string, torch.device, or device index)
        :raises Exception: If validation forward pass fails
        """
        logger.info("Validating converted checkpoint...")

        try:
            config = model.config
            vocab_size = config.transformer_layer_config.vocab_size
            hidden_size = config.transformer_layer_config.hidden_size
            max_position_embeddings = (
                config.transformer_layer_config.max_position_embeddings
            )

            # Use conservative defaults for batch size and sequence length
            batch_size = 1
            seq_length = min(16, max_position_embeddings)  # Don't exceed max length

            logger.debug(
                f"Running forward pass with batch_size={batch_size}, "
                f"seq_length={seq_length}, vocab_size={vocab_size}, "
                f"hidden_size={hidden_size}"
            )

            model.to(device)  # type: ignore[attr-defined,arg-type]
            input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).to(
                device
            )
            hidden_states = torch.randn(batch_size, seq_length, hidden_size).to(device)
            with torch.no_grad():
                model(input_ids=input_ids, hidden_states=hidden_states)  # type: ignore[operator]
            model.to("cpu")  # type: ignore[attr-defined,arg-type]

            logger.success("Validation forward pass successful")
        except Exception as exception:
            logger.error(f"Validation failed: {exception}")
            raise exception

    def _pretrained_config_from_eagle(self, eagle_config: dict) -> LlamaConfig:
        return LlamaConfig(
            vocab_size=eagle_config.get("vocab_size", 32000),
            hidden_size=eagle_config.get("hidden_size", 4096),
            intermediate_size=eagle_config.get("intermediate_size", 11008),
            num_hidden_layers=1,  # Eagle always uses a single decoder layer
            num_attention_heads=eagle_config.get("num_attention_heads", 32),
            num_key_value_heads=eagle_config.get("num_key_value_heads"),
            hidden_act=eagle_config.get("hidden_act", "silu"),
            max_position_embeddings=eagle_config.get("max_position_embeddings", 4096),
            initializer_range=eagle_config.get("initializer_range", 0.02),
            rms_norm_eps=eagle_config.get("rms_norm_eps", 1e-6),
            use_cache=eagle_config.get("use_cache", True),
            pad_token_id=eagle_config.get("pad_token_id"),
            bos_token_id=eagle_config.get("bos_token_id", 1),
            eos_token_id=eagle_config.get("eos_token_id", 2),
            tie_word_embeddings=False,  # Eagle uses separate embed_tokens from verifier
            rope_theta=eagle_config.get("rope_theta", 10000.0),
            rope_scaling=eagle_config.get("rope_scaling"),
            attention_bias=eagle_config.get("attention_bias", False),
            attention_dropout=eagle_config.get("attention_dropout", 0.0),
            mlp_bias=eagle_config.get("mlp_bias", False),
        )

    def _eagle_speculator_config(
        self,
        orig_config: dict,
        fusion_bias: bool,
        layernorms: bool,
    ) -> EagleSpeculatorConfig:
        logger.debug(
            f"Building config with fusion_bias={fusion_bias}, layernorms={layernorms} "
            f"from Eagle checkpoint config: {orig_config}"
        )
        pretrained_config = self._pretrained_config_from_eagle(orig_config)

        return EagleSpeculatorConfig(
            transformer_layer_config=pretrained_config,
            speculators_config=SpeculatorsConfig(
                algorithm="eagle",
                proposal_methods=[
                    GreedyTokenProposalConfig(
                        proposal_type="greedy",
                        speculative_tokens=5,
                    )
                ],
                default_proposal_method="greedy",
                verifier=VerifierConfig.from_pretrained(
                    self.verifier,
                ),
            ),
            layernorms=layernorms,
            fusion_bias=fusion_bias,
        )

    def _classify_param_key(
        self, weight_name: str, fusion_bias: bool, layernorms: bool
    ) -> Literal["keep", "ignore", "extra"]:
        if weight_name == "embed_tokens.weight":
            return "ignore"

        if weight_name == "fc.bias":
            return "keep" if fusion_bias else "extra"

        if weight_name in self.layernorm_mappings:
            return "keep" if layernorms else "extra"

        return (
            "keep"
            if any(weight_name.startswith(prefix) for prefix in self.weight_mappings)
            else "extra"
        )

    def _remap_param_name(self, param_name: str) -> str:
        mappings = {
            **self.weight_mappings,
            **self.layernorm_mappings,
        }
        for from_mapping, to_mapping in mappings.items():
            if param_name.startswith(from_mapping):
                return param_name.replace(from_mapping, to_mapping)

        raise ValueError(
            f"Unexpected parameter name format: {param_name}. "
            "Please check the Eagle checkpoint structure."
        )

    def _eagle_speculator_state_dict(
        self,
        orig_state_dict: dict[str, Tensor],
        fusion_bias: bool,
        layernorms: bool,
    ) -> tuple[dict[str, Tensor], list[str]]:
        logger.debug(
            f"Processing state_dict with fusion_bias={fusion_bias}, "
            f"layernorms={layernorms} from original keys: {orig_state_dict.keys()}"
        )
        converted_state_dict = {}
        extra_keys = []

        for name, tensor in orig_state_dict.items():
            param_key_action = self._classify_param_key(name, fusion_bias, layernorms)

            if param_key_action == "ignore":
                continue

            if param_key_action == "extra":
                extra_keys.append(name)
                continue

            new_name = self._remap_param_name(name)
            converted_state_dict[new_name] = tensor

        logger.debug(
            f"Converted state_dict with {list(converted_state_dict)} weights, "
            f"and {list(extra_keys)} extra keys."
        )

        return converted_state_dict, extra_keys
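
For reference, a standalone sketch of how the feature detection and prefix remapping
above behave on a toy Eagle-style state dict (key names and tensor shapes here are
illustrative, not taken from a real checkpoint)::

    import torch

    toy_state_dict = {
        "embed_tokens.weight": torch.zeros(8, 4),  # ignored (reused from verifier)
        "fc.weight": torch.zeros(4, 8),            # -> "fusion_fc.weight"
        "fc.bias": torch.zeros(4),                 # kept only when fusion_bias=True
        "layers.0.input_layernorm.weight": torch.ones(4),  # -> "transformer.input_layernorm.weight"
    }

    # Auto-detection, mirroring convert_config_state_dict():
    fusion_bias = "fc.bias" in toy_state_dict  # True for this toy dict
    layernorms = any(
        name in toy_state_dict
        for name in (
            "embed_layernorm.weight",
            "hidden_layernorm.weight",
            "lm_head_layernorm.weight",
        )
    )  # False here, so the extra layernorm remapping stays disabled

    # Prefix remapping, mirroring _remap_param_name():
    # "fc." -> "fusion_fc." and "layers.0." -> "transformer."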

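Similarly, the structural check in is_supported() reduces to three key tests; a
standalone sketch over hypothetical key sets::

    def looks_like_eagle(keys: set[str]) -> bool:
        # Mirrors is_supported(): fusion layer present and exactly one decoder layer
        has_fc = "fc.weight" in keys
        has_layers_0 = any(k.startswith("layers.0.") for k in keys)
        has_layers_non_0 = any(
            k.startswith("layers.") and not k.startswith("layers.0.") for k in keys
        )
        return has_fc and has_layers_0 and not has_layers_non_0

    assert looks_like_eagle({"fc.weight", "layers.0.self_attn.q_proj.weight"})
    assert not looks_like_eagle(
        {"fc.weight", "layers.0.mlp.up_proj.weight", "layers.1.mlp.up_proj.weight"}
    )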