2 changes: 0 additions & 2 deletions src/speculators/convert/eagle/eagle3_converter.py
@@ -178,9 +178,7 @@ def _build_eagle3_speculator_config(
return Eagle3SpeculatorConfig(
transformer_layer_config=transformer_config,
speculators_config=speculators_config,
draft_vocab_size=eagle_config.get("draft_vocab_size", 32000),
norm_before_residual=norm_before_residual,
target_hidden_size=eagle_config.get("target_hidden_size"),
)

def _create_transformer_config_from_eagle(
111 changes: 63 additions & 48 deletions src/speculators/models/eagle.py
@@ -18,7 +18,13 @@
from typing import Any, ClassVar, Literal, Optional, Union

import torch
from pydantic import Field, field_serializer, field_validator, model_validator
from pydantic import (
BaseModel,
Field,
field_serializer,
field_validator,
model_validator,
)
from torch import nn
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -32,11 +38,66 @@
__all__ = [
"EagleSpeculator",
"EagleSpeculatorConfig",
"TransformerLayerConfigMixin",
]


class TransformerLayerConfigMixin(BaseModel):
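"""
Mixin that provides the shared transformer_layer_config field along with its
pydantic validator (dict or PretrainedConfig -> PretrainedConfig via
AutoConfig.for_model) and serializer (PretrainedConfig -> dict of non-default
values via to_diff_dict), so speculator config classes can inherit it instead
of redefining it.
"""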
transformer_layer_config: PretrainedConfig = Field(
default_factory=LlamaConfig,
description=(
"Configuration object for the transformer layer architecture. "
"Must be a PretrainedConfig instance that matches the requirements "
"of the transformer_layer_architecture. Contains parameters such as "
"hidden_size, num_attention_heads, intermediate_size, vocab_size, "
"and other architecture-specific settings. "
"Additionally, it contains all the necessary information to check and "
"validate compatibility between the speculator and verifier models, "
"such as the vocab_size used for the speculator and the hidden_size "
"used for the speculator's transformer layer, which must match "
"the verifier's hidden_size according to the algorithm's design."
),
)

@field_serializer("transformer_layer_config")
def serialize_transformer_layer_config(self, value: PretrainedConfig) -> dict:
"""
Serialize the transformer_layer_config to a dictionary for JSON storage.

Converts the PretrainedConfig object to its dictionary representation
using to_diff_dict() to only include non-default values.

:param value: The PretrainedConfig instance to serialize
:return: Dictionary representation of the transformer layer configuration
"""
return value.to_diff_dict()

@field_validator("transformer_layer_config", mode="before")
@classmethod
def validate_transformer_layer_config(cls, value: Any) -> PretrainedConfig:
"""
Validate and convert transformer_layer_config to a PretrainedConfig instance.

Accepts either a dictionary that can be converted to a PretrainedConfig
or an existing PretrainedConfig instance.

:param value: The value to validate (dict or PretrainedConfig)
:return: A validated PretrainedConfig instance
:raises ValueError: If the value cannot be converted to a PretrainedConfig
"""
if isinstance(value, dict):
return AutoConfig.for_model(**value)
if isinstance(value, PretrainedConfig):
return value

raise ValueError(
"transformer_layer_config must be a PretrainedConfig instance or a "
"dictionary that can be converted to a PretrainedConfig."
)
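
For reference, here is a minimal sketch (not part of this diff) of the dict-to-PretrainedConfig round trip the mixin's validator and serializer perform. It uses only the transformers APIs referenced above (AutoConfig.for_model and to_diff_dict); the model_type and sizes are illustrative values, not taken from the PR.

from transformers import AutoConfig

# dict -> PretrainedConfig, mirroring validate_transformer_layer_config
layer_config = AutoConfig.for_model(model_type="llama", hidden_size=2048, num_attention_heads=16)

# PretrainedConfig -> dict, mirroring serialize_transformer_layer_config;
# to_diff_dict() keeps only values that differ from the defaults
layer_dict = layer_config.to_diff_dict()
assert layer_dict["hidden_size"] == 2048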


@SpeculatorModelConfig.register("eagle")
class EagleSpeculatorConfig(SpeculatorModelConfig):
class EagleSpeculatorConfig(SpeculatorModelConfig, TransformerLayerConfigMixin):
"""
A SpeculatorModelConfig implementation to be used with the EagleSpeculator
for EAGLE and HASS variants for spec decoding:
@@ -91,16 +152,6 @@ class EagleSpeculatorConfig(SpeculatorModelConfig):
"transformer decoder layer class (e.g., 'LlamaDecoderLayer')."
),
)
transformer_layer_config: PretrainedConfig = Field(
default_factory=LlamaConfig,
description=(
"Configuration object for the transformer layer architecture. "
"Must be a PretrainedConfig instance that matches the requirements "
"of the transformer_layer_architecture. Contains parameters such as "
"hidden_size, num_attention_heads, intermediate_size, vocab_size, "
"and other architecture-specific settings."
),
)
layernorms: bool = Field(
default=False,
description=(
@@ -140,42 +191,6 @@ def check_add_architectures(self) -> Self:

return self

@field_serializer("transformer_layer_config")
def serialize_transformer_layer_config(self, value: PretrainedConfig) -> dict:
"""
Serialize the transformer_layer_config to a dictionary for JSON storage.

Converts the PretrainedConfig object to its dictionary representation
using to_diff_dict() to only include non-default values.

:param value: The PretrainedConfig instance to serialize
:return: Dictionary representation of the transformer layer configuration
"""
return value.to_diff_dict()

@field_validator("transformer_layer_config", mode="before")
@classmethod
def validate_transformer_layer_config(cls, value: Any) -> PretrainedConfig:
"""
Validate and convert transformer_layer_config to a PretrainedConfig instance.

Accepts either a dictionary that can be converted to a PretrainedConfig
or an existing PretrainedConfig instance.

:param value: The value to validate (dict or PretrainedConfig)
:return: A validated PretrainedConfig instance
:raises ValueError: If the value cannot be converted to a PretrainedConfig
"""
if isinstance(value, dict):
return AutoConfig.for_model(**value)
if isinstance(value, PretrainedConfig):
return value

raise ValueError(
"transformer_layer_config must be a PretrainedConfig instance or a "
"dictionary that can be converted to a PretrainedConfig."
)


@SpeculatorModel.register("eagle")
class EagleSpeculator(SpeculatorModel):
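
As a quick sanity check (also not part of this diff), the refactor keeps EagleSpeculatorConfig's surface intact: the transformer_layer_config field and its validator/serializer are now inherited from TransformerLayerConfigMixin rather than defined on the class. A hedged sketch, assuming pydantic v2 (consistent with the field_serializer import above):

# Sketch only: uses names visible in this diff.
assert issubclass(EagleSpeculatorConfig, TransformerLayerConfigMixin)
assert "transformer_layer_config" in EagleSpeculatorConfig.model_fields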