Conversation

markurtz
Collaborator

Summary

Refactors the registry pattern from ClassRegistryMixin to a generic RegistryMixin, introduces comprehensive type annotations, and improves API consistency. Lays the foundation for a more extensible class-based converter architecture by providing better type safety, cleaner and expanded APIs, and enhanced auto-discovery capabilities.

Details

  • general:
    • Added eval_type_backport dependency for better type evaluation support
    • Updated and added type annotations using from __future__ import annotations
  • registry.py:
    • Renamed ClassRegistryMixin to RegistryMixin with generic type support RegistryMixin[RegistryObjT]
    • Introduced new generic type variables: RegistryObjT, RegisterT, BaseModelT, RegisterClassT
    • Enhanced registry methods with registered_objects(), is_registered(), and get_registered_object()
    • Updated so that typing systems will evaluate wrapped classes correctly
  • pydantic_utils.py:
    • Updated PydanticClassRegistryMixin to use generic types: PydanticClassRegistryMixin[BaseModelT]
    • Improved polymorphic validation with automatic schema reloading
    • Enhanced error handling and type validation for registered classes
    • Added registered_classes() method for better Pydantic-specific functionality
    • Updated so that typing systems will evaluate wrapped classes correctly
  • auto_importer.py:
    • Streamlined documentation and improved type annotations
    • Better error handling and validation for package discovery
    • Enhanced module import tracking and duplicate prevention
  • model.py and config.py:
    • Updated to use the changes from registry.py and pydantic_utils.py
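
For readers unfamiliar with the pattern, the registry API listed above can be sketched roughly as follows. This is a minimal illustration, not the Speculators implementation: the class name SketchRegistryMixin and the Converter/EagleConverter classes are hypothetical, and the behavior of each method (for example, returning None for an unknown name) is an assumption. Only the method names come from the list above.

```python
from __future__ import annotations

from typing import Callable, ClassVar, Generic, TypeVar

RegistryObjT = TypeVar("RegistryObjT")


class SketchRegistryMixin(Generic[RegistryObjT]):
    """Hypothetical stand-in for RegistryMixin; behavior here is assumed."""

    registry: ClassVar[dict[str, object] | None] = None

    @classmethod
    def register(cls, name: str | None = None) -> Callable[[RegistryObjT], RegistryObjT]:
        def decorator(obj: RegistryObjT) -> RegistryObjT:
            if cls.registry is None:
                # Create the dict lazily so each registering class owns its own.
                cls.registry = {}
            key = name or obj.__name__
            if key in cls.registry:
                raise ValueError(f"{key!r} is already registered")
            cls.registry[key] = obj
            return obj

        return decorator

    @classmethod
    def registered_objects(cls) -> tuple[RegistryObjT, ...]:
        return tuple((cls.registry or {}).values())

    @classmethod
    def is_registered(cls, name: str) -> bool:
        return name in (cls.registry or {})

    @classmethod
    def get_registered_object(cls, name: str) -> RegistryObjT | None:
        return (cls.registry or {}).get(name)


class Converter(SketchRegistryMixin[type]):
    """Hypothetical base class whose subclasses register themselves."""


@Converter.register("eagle")
class EagleConverter(Converter):
    pass


print(Converter.is_registered("eagle"))
print(Converter.get_registered_object("eagle"))
```

The generic parameter only informs type checkers about what the registry holds; at runtime the lookup is a plain dictionary access.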

Test Plan

  • Expanded existing unit tests and added new ones to cover the full functionality

Related Issues

N/A

@markurtz markurtz self-assigned this Aug 29, 2025

github-actions bot commented Aug 29, 2025

📦 Build Artifacts Available
The build artifacts (`.whl` and `.tar.gz`) have been successfully generated and are available for download: https://github.com/vllm-project/speculators/actions/runs/18129843369/artifacts/4143080806.
They will be retained for up to 30 days.
Commit: a39ad50

@markurtz markurtz requested a review from Copilot August 30, 2025 02:27
Copilot

This comment was marked as outdated.

@dsikka dsikka requested a review from shanjiaz September 2, 2025 15:05
@markurtz markurtz force-pushed the features/converters/utils-updates branch from 4403a77 to d7ec38f Compare September 4, 2025 22:40
@markurtz markurtz requested a review from Copilot September 4, 2025 22:48
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR refactors the registry pattern from ClassRegistryMixin to a generic RegistryMixin, introduces comprehensive type annotations, and improves API consistency. It lays the foundation for a more extensible class-based converter architecture by providing better type safety, cleaner and expanded APIs, and enhanced auto-discovery capabilities.

Key changes include:

  • Renamed ClassRegistryMixin to RegistryMixin with generic type support
  • Enhanced registry methods with registered_objects(), is_registered(), and get_registered_object()
  • Improved polymorphic validation with automatic schema reloading in Pydantic utilities
  • Updated type annotations using from __future__ import annotations

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| tests/unit/utils/test_registry.py | Comprehensive test suite for refactored RegistryMixin with new methods and type validation |
| tests/unit/utils/test_pydantic_utils.py | Updated tests for PydanticClassRegistryMixin with enhanced type safety and generic support |
| tests/unit/utils/test_auto_importer.py | Restructured tests for AutoImporterMixin with improved fixtures and edge case coverage |
| tests/unit/test_model.py | Minor type comment additions for registry attribute access |
| tests/unit/models/test_eagle_model.py | Minor type comment additions for registry attribute access |
| tests/unit/models/test_eagle_config.py | Minor type comment additions for registry attribute access |
| src/speculators/utils/registry.py | Core refactor from ClassRegistryMixin to generic RegistryMixin with enhanced API |
| src/speculators/utils/pydantic_utils.py | Enhanced Pydantic integration with generic type support and schema reloading |
| src/speculators/utils/auto_importer.py | Streamlined documentation and improved type annotations |
| src/speculators/utils/__init__.py | Updated imports to reflect registry class rename |
| src/speculators/model.py | Updated to use new RegistryMixin with generic type parameter |
| src/speculators/convert/eagle/__init__.py | Minor documentation update |
| pyproject.toml | Added eval_type_backport dependency for type evaluation support |

@markurtz markurtz changed the base branch from features/converters/main to main September 8, 2025 22:35
@markurtz markurtz force-pushed the features/converters/utils-updates branch 3 times, most recently from 99e7ff3 to 4f47c07 Compare September 8, 2025 23:26
@markurtz markurtz force-pushed the features/converters/utils-updates branch from 4f47c07 to c67c12f Compare September 8, 2025 23:26
rahul-tuli previously approved these changes Sep 10, 2025
@fynnsu
Collaborator

fynnsu commented Sep 11, 2025

Could you explain the use case for ReloadableBaseModel? Maybe with a concrete example?

It seems fairly complicated to iterate through all descendants of BaseModel and update them, and I'm not sure I fully understand what we're trying to accomplish with this.

@markurtz
Collaborator Author

markurtz commented Sep 12, 2025

@fynnsu (cc @shanjiaz), here’s the rationale behind expanding the ReloadableBaseModel.

TL;DR

  • We need a robust, extensible way to serialize/deserialize polymorphic models in Speculators.
  • Naive generics → work but require explicit typing everywhere.
  • Registry + Discriminated Unions → solves resolution but unions are static (can’t extend externally).
  • Automatic Discriminated Unions → dynamic but requires manual model_rebuild calls in correct nested order.
  • ReloadableBaseModel → automates recursive reloads, ensuring stable schemas without developer intervention.
  • Relies only on stable Python/Pydantic APIs, so low long-term risk if well-tested.

The high-level requirement is that we need to represent polymorphic types for the instance types/fields that extend our base classes and interfaces. These base classes/interfaces define a standard format in Speculators and ensure consistent, serializable/deserializable objects. We rely on Pydantic’s BaseModel to minimize boilerplate around serialization, deserialization, and validation. Any core classes, nested classes, polymorphic types, or more complex objects should extend BaseModel at a minimum. Current core examples for the polymorphic types are TokenProposalConfig and SpeculatorModelConfig, with VerifierConfig potentially following.

1. Naive Implementation

The simplest approach explicitly references classes. To make Pydantic handle custom reloads, we use generics to tell it which class is involved. This forces prior knowledge of what’s being constructed and requires referencing the specific types everywhere. It works, but only recovers fields when typed explicitly, which makes it brittle for real use.

from typing import Generic, TypeVar
from pydantic import BaseModel

class BaseExample(BaseModel):
    shared_field: str = ""

class NestedExample(BaseExample):
    extra_field: int = 1

NestedT = TypeVar("NestedT", bound=BaseExample)

class ParentExample(BaseModel, Generic[NestedT]):
    nested: NestedT

def test_marshalling():
    nested = NestedExample()
    parent = ParentExample(nested=nested)
    parent_dict = parent.model_dump()
    parent_rec_notype = ParentExample.model_validate(parent_dict)
    print(parent_rec_notype.nested)  # only BaseExample fields
    parent_rec_typed = ParentExample[NestedExample].model_validate(parent_dict)
    print(parent_rec_typed.nested)  # includes NestedExample fields

if __name__ == "__main__":
    test_marshalling()

2. Registry + Discriminated Unions

The naive case fails because:

  1. No automatic class resolution (so from_pretrained won’t work).
  2. No automatic deserialization resolution (so Class.model_validate(dict) won’t work).

We can solve this with a registry pattern plus Pydantic’s discriminated unions. The registry enables runtime class lookup and construction. But discriminated unions are statically defined, which brings two hard limitations: external libraries, contrib packages, and prototypes can’t extend them without editing the source to add the new type, and developers need to know about and update every container field that holds the polymorphic class.

from typing import Literal, Union
from pydantic import BaseModel, Field, ValidationError
from speculators.utils import RegistryMixin

class BaseExample(BaseModel, RegistryMixin):
    @classmethod
    def create(cls, type_: str, **kwargs) -> "BaseExample":
        example_class = cls.registry[type_]
        return example_class(**kwargs)

    type_: Literal["base"] = "base"
    shared_field: str = ""

@BaseExample.register()
class NestedExample(BaseExample):
    type_: Literal["nested_internal"] = "nested_internal"
    extra_field: int = 1

class ParentExample(BaseModel):
    nested: Union[NestedExample, BaseExample] = Field(discriminator="type_")

@BaseExample.register()
class ExternalNestedExample(BaseExample):
    type_: Literal["nested_external"] = "nested_external"
    different_field: bool = True

def test_marshalling():
    nested = BaseExample.create(type_="NestedExample")
    print(type(nested))  # NestedExample
    parent = ParentExample(nested=nested)
    parent_dict = parent.model_dump()
    parent_rec = ParentExample.model_validate(parent_dict)
    print(parent_rec.nested)  # NestedExample fields

    external = BaseExample.create(type_="ExternalNestedExample")
    print(type(external))  # ExternalNestedExample
    try:
        ParentExample(nested=external)
    except ValidationError as error:
        print(error)  # fails, union not updated at runtime

if __name__ == "__main__":
    test_marshalling()

3. Automatic Discriminated Unions

The registry + discriminated unions case fails when:

  1. The code needs to be externally modified, extended, or adapted.
  2. The container and base live in the same module while children are defined externally. This risks circular imports, since the parent must reference the subclasses while the subclasses also need to register with the base.

To unlock external extension and remove risks of circular imports, we need a dynamic discriminated union. This can be achieved by:

  • Declaring a schema_discriminator.
  • Overriding __get_pydantic_core_schema__ to only build discriminated unions for the base class.
  • Reloading the schema when a new type is registered.

This is implemented in Speculators via the PydanticClassRegistryMixin, which combines Pydantic’s BaseModel, a typed registry, and the auto-union functionality.

The catch: Pydantic caches schemas. Unless we manually call model_rebuild on any container classes, the new types won’t be recognized during validation. That puts the burden on implementors (internal or external) to remember where and in what order to reload models.

from typing import ClassVar, Literal
from pydantic import BaseModel, Field, ValidationError
from speculators.utils import PydanticClassRegistryMixin

# Note: to reproduce this, swap ReloadableBaseModel for BaseModel in
# PydanticClassRegistryMixin; otherwise it follows the automatic pathway shown later
class BaseExample(PydanticClassRegistryMixin):
    @classmethod
    def create(cls, type_: str, **kwargs) -> "BaseExample":
        example_class = cls.registry[type_]
        return example_class(**kwargs)

    @classmethod
    def __pydantic_schema_base_type__(cls) -> type["BaseExample"]:
        if cls.__name__ == "BaseExample":
            return cls
        return BaseExample

    schema_discriminator: ClassVar[str] = "type_"

    type_: Literal["base"] = "base"
    shared_field: str = ""

@BaseExample.register("nested_internal")
class NestedExample(BaseExample):
    type_: Literal["nested_internal"] = "nested_internal"
    extra_field: int = 1

class ParentExample(BaseModel):
    nested: BaseExample = Field()

@BaseExample.register("nested_external")
class ExternalNestedExample(BaseExample):
    type_: Literal["nested_external"] = "nested_external"
    different_field: bool = True

def test_marshalling():
    nested = BaseExample.create(type_="nested_internal")
    parent = ParentExample(nested=nested)
    parent_dict = parent.model_dump()
    parent_rec = ParentExample.model_validate(parent_dict)
    print(parent_rec.nested)  # NestedExample fields

    external = BaseExample.create(type_="nested_external")
    try:
        ParentExample(nested=external)
    except ValidationError:
        print("Validation fails without reload")

    # Fix with reload
    ParentExample.model_rebuild(force=True)
    parent_ext = ParentExample(nested=external)
    parent_ext_rec = ParentExample.model_validate(parent_ext.model_dump())
    print(parent_ext_rec.nested)  # ExternalNestedExample fields

if __name__ == "__main__":
    test_marshalling()

4. Automatic Reloadable Discriminated Unions

The above approach requires calling model_rebuild in the correct nested order: inner → outer. For example, after adding a new TokenProposalConfig, we’d need to reload SpeculatorsConfig first and then SpeculatorModelConfig. Missing or mis-ordering reload calls leads to hard-to-debug cases where explicitly created objects work but can’t be deserialized back.

The ReloadableBaseModel automates this. It traverses class relationships and recursively reloads Pydantic models whenever a new subclass is registered. This guarantees stable schemas without manual intervention.

from typing import ClassVar, Literal
from pydantic import BaseModel, Field
from speculators.utils import PydanticClassRegistryMixin

class BaseExample(PydanticClassRegistryMixin):
    @classmethod
    def create(cls, type_: str, **kwargs) -> "BaseExample":
        example_class = cls.registry[type_]
        return example_class(**kwargs)

    @classmethod
    def __pydantic_schema_base_type__(cls) -> type["BaseExample"]:
        if cls.__name__ == "BaseExample":
            return cls
        return BaseExample

    schema_discriminator: ClassVar[str] = "type_"

    type_: Literal["base"] = "base"
    shared_field: str = ""

@BaseExample.register("nested_internal")
class NestedExample(BaseExample):
    type_: Literal["nested_internal"] = "nested_internal"
    extra_field: int = 1

class ParentExample(BaseModel):
    nested: BaseExample = Field()

@BaseExample.register("nested_external")
class ExternalNestedExample(BaseExample):
    type_: Literal["nested_external"] = "nested_external"
    different_field: bool = True

def test_marshalling():
    nested = BaseExample.create(type_="nested_internal")
    parent = ParentExample(nested=nested)
    print(parent.nested)  # NestedExample fields

    external = BaseExample.create(type_="nested_external")
    parent_ext = ParentExample(nested=external)
    print(parent_ext.nested)  # ExternalNestedExample fields

if __name__ == "__main__":
    test_marshalling()
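
To illustrate the inner → outer ordering described above, here is a rough, standard-library-only sketch of how an affected rebuild order could be computed. It is not the ReloadableBaseModel code: the helper names (referenced_classes, rebuild_order) and the GrandparentExample/Unrelated classes are made up for the example, and plain annotated classes stand in for Pydantic models. With real models, one would presumably call model_rebuild(force=True) on each class in the returned order.

```python
import typing
from graphlib import TopologicalSorter


def referenced_classes(cls):
    """Collect classes named in cls's field annotations, unwrapping list[X] etc."""
    found = set()

    def walk(tp):
        if typing.get_origin(tp) is None and isinstance(tp, type):
            found.add(tp)
        for arg in typing.get_args(tp):
            walk(arg)

    for hint in typing.get_type_hints(cls).values():
        walk(hint)
    return found


def rebuild_order(changed, candidates):
    """Return the candidates that (transitively) contain `changed`, innermost first."""
    deps = {cls: referenced_classes(cls) for cls in candidates}
    affected = {changed}
    grew = True
    while grew:  # expand until no further container references an affected class
        grew = False
        for cls, refs in deps.items():
            if cls not in affected and refs & affected:
                affected.add(cls)
                grew = True
    # Map each affected container to the affected classes it depends on.
    graph = {
        cls: (deps[cls] & affected) - {cls}
        for cls in affected
        if cls is not changed
    }
    ordered = TopologicalSorter(graph).static_order()  # dependencies come first
    return [cls for cls in ordered if cls is not changed]


class BaseExample:
    shared_field: str


class ParentExample:
    nested: BaseExample


class GrandparentExample:
    parent: ParentExample
    fallbacks: list[BaseExample]  # also references the base directly


class Unrelated:
    name: str


order = rebuild_order(BaseExample, [GrandparentExample, Unrelated, ParentExample])
print([cls.__name__ for cls in order])
```

GrandparentExample references BaseExample directly, so a naive "rebuild everything that mentions the base" pass could rebuild it before ParentExample. The topological sort is what avoids that mis-ordering.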

Risks

The reload logic is non-trivial since it traverses nested relationships, inspects Pydantic fields, and recursively reloads. With strong tests, though, the risk is limited to a single utility class that is a thin extension on top of Python and Pydantic and depends only on core, stable APIs:

  • Python
    • __subclasses__ for class discovery
    • isinstance / issubclass for type checks
    • typing utilities for generics inspection
  • Pydantic
    • model_rebuild to force schema refresh
    • model_json_schema to validate schema stability
    • model_fields to inspect nested fields

These are mature, stable APIs that rarely change except on major library updates.
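
As a small illustration of the class-discovery piece, the sketch below walks __subclasses__ recursively to collect every transitive subclass of a root class. The function name all_descendants and the example hierarchy are hypothetical; the actual traversal in ReloadableBaseModel may be structured differently.

```python
def all_descendants(root):
    """Return every direct and indirect subclass of `root`, each exactly once."""
    seen = set()
    stack = list(root.__subclasses__())
    while stack:
        cls = stack.pop()
        if cls not in seen:
            seen.add(cls)
            # __subclasses__ only lists direct children, so keep descending.
            stack.extend(cls.__subclasses__())
    return seen


class Root:
    pass


class Left(Root):
    pass


class Right(Root):
    pass


class Diamond(Left, Right):  # reachable through two paths, collected once
    pass


print(sorted(cls.__name__ for cls in all_descendants(Root)))
```

Starting the walk from a narrower root class instead of the broadest base shrinks the set that needs inspecting.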

Potential Improvements

Right now we traverse all BaseModel subclasses. We could simplify this by requiring all classes in Speculators to extend from ReloadableBaseModel instead, reducing the search space. Either way, we need to rebuild the relationship tree at runtime to preserve dynamic extensibility.

@fynnsu
Collaborator

fynnsu commented Sep 12, 2025

@markurtz Thank you for the detailed walkthrough, I feel like I've learned a lot about pydantic just reading this!

Your explanation makes sense to me, but I'm still wondering if there's a better way to solve this. Below I've attached my attempt at this. I think it correctly handles dynamic registration and serialization/deserialization without blocking external models or requiring overly complex type annotations.

Please let me know what you think! (Also for simplicity I just created a simple registration system but we could probably get this to work with the PydanticClassRegistryMixin)

from collections.abc import Mapping
from typing import Annotated, Literal

from pydantic import BaseModel, BeforeValidator, SerializeAsAny

registry = {}


def register(type_):
    def decorator(cls):
        registry[type_] = cls
        return cls

    return decorator


@register("base")
class BaseExample(BaseModel):
    type_: Literal["base"] = "base"
    shared_field: str = ""

    @classmethod
    def create(cls, type_: str, **kwargs) -> "BaseExample":
        example_class = registry[type_]
        return example_class(**kwargs)


@register("nested_internal")
class NestedExample(BaseExample):
    type_: Literal["nested_internal"] = "nested_internal"
    extra_field: int = 1


def validate_subclass(obj):
    type_ = obj["type_"] if isinstance(obj, Mapping) else obj.type_
    return registry[type_].model_validate(obj)


class ParentExample(BaseModel):
    nested: Annotated[SerializeAsAny[BaseExample], BeforeValidator(validate_subclass)] # here


@register("nested_external")
class ExternalNestedExample(BaseExample):
    type_: Literal["nested_external"] = "nested_external"
    different_field: bool = True


def test_marshalling():
    nested = BaseExample.create(type_="nested_internal")
    parent = ParentExample(nested=nested)
    print(parent.nested)  # NestedExample fields

    external = BaseExample.create(type_="nested_external")
    parent_ext = ParentExample(nested=external)
    print(parent_ext.nested)  # ExternalNestedExample fields

    parent_ext_rec = ParentExample.model_validate(parent_ext.model_dump())
    print(parent_ext_rec.nested)  # ExternalNestedExample fields


if __name__ == "__main__":
    test_marshalling()

In the line marked # here, you'll see the main part of the approach is:

  1. SerializeAsAny, which ensures the full class gets dumped, not just the BaseExample fields (see pydantic/pydantic#9897, "Serialization of subclasses that have extra fields")
  2. BeforeValidator(validate_subclass) which dynamically maps the loaded obj to the right subclass
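
To make the first point concrete, here is a minimal sketch, assuming Pydantic v2, that compares how a subclass instance is dumped with and without SerializeAsAny. The container names PlainParent and DuckParent are invented for this comparison and are not part of the example above.

```python
from pydantic import BaseModel, SerializeAsAny


class BaseExample(BaseModel):
    shared_field: str = ""


class NestedExample(BaseExample):
    extra_field: int = 1


class PlainParent(BaseModel):
    # Serialized according to the declared type, i.e. BaseExample's fields only.
    nested: BaseExample


class DuckParent(BaseModel):
    # Serialized according to the runtime type of the value.
    nested: SerializeAsAny[BaseExample]


plain_dump = PlainParent(nested=NestedExample()).model_dump()
duck_dump = DuckParent(nested=NestedExample()).model_dump()

print(plain_dump)  # extra_field is missing
print(duck_dump)   # extra_field is present
```

Both containers accept the NestedExample instance at validation time; the difference only appears when dumping.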

@markurtz
Collaborator Author

Adding in here as well, @fynnsu: it is a simpler implementation, but it removes several core things, which will lead to unexpected bugs, performance issues, and general docs/understanding issues. It breaks down to the following:

  1. There is no native Pydantic schema support through this pathway. Everything is treated as "Any", and our custom implementation on top of Pydantic has to own all of the validation. The issue is not the amount of code; it is that we essentially throw away Pydantic's powerful built-ins in exchange for minimal checks and completely lose the ability to generate schemas.
  2. Container classes referencing the instance need to use the annotated type rather than the base class. This quickly becomes confusing for anyone who wasn't involved in this specific API decision and, much worse, it is a silent error if they forget the annotated type and use the base class instead.
  3. A lesser issue for Speculators currently: because we inject Python-level logic (for loops and the like), we lose Pydantic's performant serialization and deserialization, and that cost lands on the runtime hot path rather than at compile/initialization time.

I have gone through and pushed an update to build the dependency tree and target only the chains that would be affected. Please take a look through.

@markurtz markurtz requested a review from rahul-tuli September 16, 2025 17:49