Skip to content
157 changes: 77 additions & 80 deletions src/speculators/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,22 @@

This module provides a command-line interface for creating and managing speculative
decoding models. The CLI is built using Typer and provides commands for model
conversion, version information, and other utilities.

The CLI can be accessed through the `speculators` command after installation, or by
running this module directly with `python -m speculators`.

Commands:
convert: Convert models from external repos/formats to supported Speculators models
version: Display the current version of the Speculators library

Usage:
$ speculators --help
$ speculators --version
$ speculators convert <model> [OPTIONS]
conversion, version information, and other utilities. It serves as the primary
entry point for users to interact with the Speculators library from the command line.

Example:
::
speculators --help
speculators convert "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" \
--algorithm eagle \
--verifier "meta-llama/Llama-3.1-8B-Instruct"
"""

from __future__ import annotations

import json
from importlib.metadata import version as pkg_version
from typing import Annotated, Any, Optional
from typing import Annotated, Any, Literal, cast

import click
import typer # type: ignore[import-not-found]
Expand All @@ -37,16 +35,11 @@
)


def version_callback(value: bool):
def version_callback(value: bool) -> None:
"""
Callback function to print the version of the Speculators package and exit.

This function is used as a callback for the --version option in the main CLI.
When the version option is specified, it prints the version information and
exits the application.
Print the Speculators package version and exit.

:param value: Boolean indicating whether the version option was specified.
If True, prints version and exits.
:param value: Whether the version option was specified
"""
if value:
typer.echo(f"speculators version: {pkg_version('speculators')}")
Expand All @@ -65,12 +58,8 @@ def speculators(
"""
Main entry point for the Speculators CLI application.

This function serves as the root command callback and handles global options
such as version display. It is automatically called by Typer when the CLI
is invoked.

:param ctx: The Typer context object containing runtime information.
:param version: Boolean option to display version information and exit.
:param ctx: Typer context object containing runtime information
:param version: Option to display version information and exit
"""


Expand All @@ -79,40 +68,41 @@ def convert(
model: Annotated[
str, typer.Argument(help="Model checkpoint or Hugging Face model ID to convert")
],
output_path: Annotated[
str, typer.Option(help="Directory path where converted model will be saved")
] = "converted",
config: str | None = None,
verifier: Annotated[
str,
str | None,
typer.Option(
"--verifier",
help=(
"Verifier model checkpoint or Hugging Face model ID "
"to attach as the verification/base model for speculative decoding"
),
),
],
algorithm: Annotated[
str,
typer.Option(
help=(
"The source repo/algorithm to convert from into the matching algorithm "
"in Speculators"
),
click_type=click.Choice(["eagle", "eagle3"]),
),
],
output_path: Annotated[
str, typer.Option(help="Directory path where converted model will be saved")
] = "converted",
] = None,
validate_device: Annotated[
Optional[str],
str | None,
typer.Option(
help=(
"Device to validate the model on (e.g. 'cuda:0') "
"If not provided, validation is skipped."
),
),
] = None,
algorithm: Annotated[
str,
typer.Option(
help=(
"The source repo/algorithm to convert from into the matching algorithm "
"in Speculators"
),
click_type=click.Choice(["auto", "eagle", "eagle2", "hass"]),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, is leaving out eagle3 intentional?

),
] = "auto",
algorithm_kwargs: Annotated[
Optional[dict[str, Any]],
dict[str, Any] | None,
typer.Option(
parser=json.loads,
help=(
Expand All @@ -122,52 +112,59 @@ def convert(
),
),
] = None,
):
cache_dir: str | None = None,
force_download: bool = False,
local_files_only: bool = False,
token: str | None = None,
revision: str | None = None,
) -> None:
"""
Convert models from external research repositories or formats
into the standardized Speculators format for use within the Speculators
framework, Hugging Face model hub compatibility, and deployment with vLLM.
Supported algorithms, repositories, and examples given below.

\b
algorithm=="eagle":
Eagle v1, v2: https://github.com/SafeAILab/EAGLE
HASS: https://github.com/HArmonizedSS/HASS
::
# general
speculators convert "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" \\
--algorithm eagle \\
Convert models from external research repositories into Speculators format.

Converts models from research implementations (EAGLE, HASS) into standardized
Speculators format for use with Hugging Face, vLLM, and the Speculators framework.

[EAGLE v1, v2](https://github.com/SafeAILab/EAGLE),
and [HASS](https://github.com/HArmonizedSS/HASS) Example:
::
speculators convert "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" \
--verifier "meta-llama/Llama-3.1-8B-Instruct"

# with layernorms and fusion bias enabled
speculators convert "./eagle/checkpoint" \\
--algorithm eagle \\
--algorithm-kwargs '{"layernorms": true, "fusion_bias": true}' \\
speculators convert "./eagle/checkpoint" \
--algorithm-kwargs '{"layernorms": true, "fusion_bias": true}' \
--verifier "meta-llama/Llama-3.1-8B-Instruct"

\b
algorithm=="eagle3":
Eagle v3: https://github.com/SafeAILab/EAGLE
::
# general
speculators convert "./eagle/checkpoint" \\
--algorithm eagle3
# eagle3 with normalization before the residual
--algorithm-kwargs '{"norm_before_residual": true}' \
--verifier "meta-llama/Llama-3.1-8B-Instruct"
# with normalization before the residual
speculators convert "./eagle/checkpoint" \\
--algorithm eagle3
--algorithm-kwargs '{"norm_before_residual": true}'
--verifier "meta-llama/Llama-3.1-8B-Instruct"
"""
if not algorithm_kwargs:
algorithm_kwargs = {}

:param model: Model checkpoint path or Hugging Face model ID to convert
:param output_path: Directory path where converted model will be saved
:param config: Optional config path, model ID, or config instance
:param verifier: Optional verifier model for speculative decoding
:param validate_device: Optional device for post-conversion validation
:param algorithm: Source algorithm to convert from (auto, eagle, eagle2, hass)
:param algorithm_kwargs: Additional conversion algorithm keyword arguments
:param cache_dir: Optional directory for caching downloaded model files
:param force_download: Force re-downloading files even if cached
:param local_files_only: Use only local files without downloading from hub
:param token: Optional Hugging Face authentication token for private models
:param revision: Optional Git revision for downloading from Hugging Face hub
"""
convert_model(
model=model,
verifier=verifier,
output_path=output_path,
config=config,
verifier=verifier,
validate_device=validate_device,
algorithm=algorithm, # type: ignore[arg-type]
**algorithm_kwargs,
algorithm=cast('Literal["auto", "eagle", "eagle2", "hass"]', algorithm),
algorithm_kwargs=algorithm_kwargs or {},
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
)


Expand Down
4 changes: 3 additions & 1 deletion src/speculators/convert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
- HASS: https://github.com/HArmonizedSS/HASS
"""

from .converters import SpeculatorConverter
from .eagle import Eagle3Converter, EagleConverter
from .entrypoints import convert_model

__all__ = ["convert_model"]
__all__ = ["Eagle3Converter", "EagleConverter", "SpeculatorConverter", "convert_model"]
14 changes: 14 additions & 0 deletions src/speculators/convert/converters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
Registry-based converter architecture for transforming external checkpoints.

This module provides the converter framework for standardizing external research model
checkpoints into the Speculators format. The converter system uses a registry pattern
to automatically detect and instantiate appropriate converters based on algorithm type
and model characteristics, supporting extensible conversion workflows with validation.
"""

from __future__ import annotations

from .base import SpeculatorConverter

__all__ = ["SpeculatorConverter"]
Loading