Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
class SQLAlchemyConnectionField(ConnectionField):
@property
def type(self):
from .types import SQLAlchemyObjectType
from .types import SQLAlchemyBase

type_ = super(ConnectionField, self).type
nullable_type = get_nullable_type(type_)
if issubclass(nullable_type, Connection):
return type_
assert issubclass(nullable_type, SQLAlchemyObjectType), (
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
assert issubclass(nullable_type, SQLAlchemyBase), (
"SQLALchemyConnectionField only accepts SQLAlchemyBase types, not {}"
).format(nullable_type.__name__)
assert nullable_type.connection, "The type {} doesn't have a connection".format(
nullable_type.__name__
Expand Down
11 changes: 3 additions & 8 deletions graphene_sqlalchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,10 @@ def get_graphene_enum_for_sa_enum(self, sa_enum: SQLAlchemyEnumType):
return self._registry_enums.get(sa_enum)

def register_sort_enum(self, obj_type, sort_enum: Enum):
from .types import SQLAlchemyBase

from .types import SQLAlchemyObjectType

if not isinstance(obj_type, type) or not issubclass(
obj_type, SQLAlchemyObjectType
):
raise TypeError(
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
)
if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase):
raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type))
if not isinstance(sort_enum, type(Enum)):
raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum))
self._registry_sort_enums[obj_type] = sort_enum
Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def resolver(_obj, _info):


def test_type_assert_sqlalchemy_object_type():
with pytest.raises(AssertionError, match="only accepts SQLAlchemyObjectType"):
with pytest.raises(AssertionError, match="only accepts SQLAlchemyBase types"):
SQLAlchemyConnectionField(ObjectType).type


Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class Meta:
[("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))],
)

re_err = r"Expected SQLAlchemyObjectType, but got: .*PetSort.*"
re_err = r"Expected SQLAlchemyBase, but got: .*PetSort.*"
with pytest.raises(TypeError, match=re_err):
reg.register_sort_enum(sort_enum, sort_enum)

Expand Down