From d2e299684ac80bd215a702c2848ef2a86a2ae03d Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 22 Feb 2022 16:30:30 +0800 Subject: [PATCH 1/4] define the queryset_class in model --- docs/declaring_models.md | 5 +++++ orm/__init__.py | 3 ++- orm/models.py | 11 ++++++++++- tests/test_models.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/docs/declaring_models.md b/docs/declaring_models.md index 1099ef4..e46964d 100644 --- a/docs/declaring_models.md +++ b/docs/declaring_models.md @@ -18,8 +18,13 @@ models = orm.ModelRegistry(database=database) class Note(orm.Model): + class MyQuerySet(QuerySet): + ... + tablename = "notes" registry = models + # or do not define the queryset_class + queryset_class = MyQuerySet fields = { "id": orm.Integer(primary_key=True), "text": orm.String(max_length=100), diff --git a/orm/__init__.py b/orm/__init__.py index ee3ef5c..553c96a 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -20,7 +20,7 @@ Text, Time, ) -from orm.models import Model, ModelRegistry +from orm.models import Model, ModelRegistry, QuerySet __version__ = "0.3.1" __all__ = [ @@ -49,4 +49,5 @@ "UUID", "Model", "ModelRegistry", + "QuerySet" ] diff --git a/orm/models.py b/orm/models.py index b402814..2075c85 100644 --- a/orm/models.py +++ b/orm/models.py @@ -83,6 +83,12 @@ def __new__(cls, name, bases, attrs): if "tablename" not in attrs: setattr(model_class, "tablename", name.lower()) + if "queryset_class" in attrs: + if not isinstance(attrs["queryset_class"], QuerySet): + raise ValueError("queryset must extend QuerySet class.") + else: + attrs["queryset_class"] = None + for name, field in attrs.get("fields", {}).items(): setattr(field, "registry", attrs.get("registry")) if field.primary_key: @@ -485,7 +491,6 @@ def _prepare_order_by(self, order_by: str): class Model(metaclass=ModelMeta): - objects = QuerySet() def __init__(self, **kwargs): if "pk" in kwargs: @@ -497,6 +502,10 @@ def __init__(self, **kwargs): ) setattr(self, key, value) + @property + def objects(self) -> QuerySet: + return cls.queryset_class() if cls.queryset_class else QuerySet() + @property def pk(self): return getattr(self, self.pkname) diff --git a/tests/test_models.py b/tests/test_models.py index 8e24437..0f213e8 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -32,6 +32,24 @@ class Product(orm.Model): } +class Book(orm.Model): + class MyQuerySet(QuerySet): + + async def get_or_none(self, **kwargs): + try: + return suepr().get(**kwargs) + except NoMatch: + return None + + tablename = "products" + registry = models + queryset_class = QuerySet + fields = { + "id": orm.Integer(primary_key=True), + "name": orm.String(max_length=100), + } + + @pytest.fixture(autouse=True, scope="function") async def create_test_database(): await models.create_all() @@ -333,3 +351,13 @@ async def test_model_sqlalchemy_filter_operators(): shirt == await Product.objects.filter(Product.columns.name.contains("Cotton")).get() ) + + +async def test_queryset_class(): + await Book.objects.create(name="book") + + b = await Book.objects.get_or_none(name="book") + assert b + + b = await Book.objects.get_or_none(name="books") + assert b is None From 3a92a6395f06870a96a55ba8d1d8854b0b4c7264 Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 22 Feb 2022 16:33:15 +0800 Subject: [PATCH 2/4] fix lint --- tests/test_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 0f213e8..c6b6bcf 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -37,8 +37,9 @@ class MyQuerySet(QuerySet): async def get_or_none(self, **kwargs): try: - return suepr().get(**kwargs) + return await super().get(**kwargs) except NoMatch: + # or raise HttpException(404) return None tablename = "products" From b615d5cf9dbe89b03ca47ce21a86b976b6765d29 Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 22 Feb 2022 16:41:53 +0800 Subject: [PATCH 3/4] fixc --- orm/models.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/orm/models.py b/orm/models.py index 2075c85..ac5534e 100644 --- a/orm/models.py +++ b/orm/models.py @@ -83,11 +83,7 @@ def __new__(cls, name, bases, attrs): if "tablename" not in attrs: setattr(model_class, "tablename", name.lower()) - if "queryset_class" in attrs: - if not isinstance(attrs["queryset_class"], QuerySet): - raise ValueError("queryset must extend QuerySet class.") - else: - attrs["queryset_class"] = None + attrs.setdefault("queryset_class", QuerySet) for name, field in attrs.get("fields", {}).items(): setattr(field, "registry", attrs.get("registry")) @@ -504,7 +500,7 @@ def __init__(self, **kwargs): @property def objects(self) -> QuerySet: - return cls.queryset_class() if cls.queryset_class else QuerySet() + return cls.queryset_class() @property def pk(self): From a12bf50fbb28007dbaf44de886476374d16f11b5 Mon Sep 17 00:00:00 2001 From: huangsong Date: Tue, 22 Feb 2022 16:55:31 +0800 Subject: [PATCH 4/4] fix queryset class in metaclass --- orm/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orm/models.py b/orm/models.py index ac5534e..c001585 100644 --- a/orm/models.py +++ b/orm/models.py @@ -83,7 +83,7 @@ def __new__(cls, name, bases, attrs): if "tablename" not in attrs: setattr(model_class, "tablename", name.lower()) - attrs.setdefault("queryset_class", QuerySet) + model_class.queryset_class = attrs.get("queryset_class") for name, field in attrs.get("fields", {}).items(): setattr(field, "registry", attrs.get("registry")) @@ -500,7 +500,7 @@ def __init__(self, **kwargs): @property def objects(self) -> QuerySet: - return cls.queryset_class() + return self.queryset_class() if self.queryset_class else QuerySet() @property def pk(self):