Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions pkgs/base/swarmauri_base/DynamicBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,16 +586,18 @@ def decorator(model_cls: Type[BaseModel]):
@classmethod
def register_type(
cls,
resource_type: Optional[Union[Type[T], List[Type[T]]]] = None,
resource_type: Optional[
Union[Type[T], List[Type[T]], Tuple[Type[T], ...]]
] = None,
type_name: Optional[str] = None,
):
"""
Decorator to register a subtype under one or more base models in the unified registry.

Parameters:
resource_type (Optional[Union[Type[T], List[Type[T]]]]):
The base model(s) under which to register the subtype. If None, all direct base classes (except DynamicBase)
are used.
resource_type (Optional[Union[Type[T], List[Type[T]], Tuple[Type[T], ...]]]):
The base model(s) under which to register the subtype. If ``None``, all
direct base classes (except ``DynamicBase``) are used.
type_name (Optional[str]): An optional custom type name for the subtype.

Returns:
Expand All @@ -608,10 +610,10 @@ def decorator(subclass: Type["DynamicBase"]):
resource_types = [
base for base in subclass.__bases__ if base is not cls
]
elif not isinstance(resource_type, list):
resource_types = [resource_type]
elif isinstance(resource_type, (list, tuple)):
resource_types = list(resource_type)
else:
resource_types = resource_type
resource_types = [resource_type]

for rt in resource_types:
if not issubclass(subclass, rt):
Expand Down
29 changes: 29 additions & 0 deletions pkgs/base/tests/unit/DynamicBase_multi_resource_unit_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Unit tests for registering to multiple resource types."""

import pytest

from swarmauri_base.DynamicBase import DynamicBase


@DynamicBase.register_model()
class ResourceTypeA(DynamicBase):
"""Dummy base model A."""


@DynamicBase.register_model()
class ResourceTypeB(DynamicBase):
"""Dummy base model B."""


@DynamicBase.register_type(resource_type=(ResourceTypeA, ResourceTypeB))
class MultiResource(ResourceTypeA, ResourceTypeB):
"""Model registered to two resource types."""


@pytest.mark.unit
def test_register_multiple_resource_types():
"""Ensure a model registers under each specified resource type."""
reg_a = DynamicBase._registry["ResourceTypeA"]["subtypes"]
reg_b = DynamicBase._registry["ResourceTypeB"]["subtypes"]
assert reg_a["MultiResource"] is MultiResource
assert reg_b["MultiResource"] is MultiResource