From 3ae57251d137c69c24b5736206f8b8748d5e7f00 Mon Sep 17 00:00:00 2001 From: Adithya Narasinghe Date: Mon, 5 Dec 2022 22:35:16 +0530 Subject: [PATCH 01/10] Add support for directly querying the DB --- labml_db/driver/__init__.py | 5 ++++- labml_db/driver/mongo.py | 20 ++++++++++++++++---- labml_db/model.py | 8 +++++++- labml_db/types.py | 2 ++ 4 files changed, 29 insertions(+), 6 deletions(-) diff --git a/labml_db/driver/__init__.py b/labml_db/driver/__init__.py index 20d9500..69ce3a6 100644 --- a/labml_db/driver/__init__.py +++ b/labml_db/driver/__init__.py @@ -1,6 +1,6 @@ from typing import List, Type, TYPE_CHECKING, Optional -from ..types import ModelDict +from ..types import ModelDict, QueryDict, SortDict if TYPE_CHECKING: from .. import Serializer, Model @@ -28,3 +28,6 @@ def msave_dict(self, key: List[str], data: List[ModelDict]): def get_all(self) -> List[str]: raise NotImplementedError + + def get_by_dict(self, query: Optional[QueryDict], sort: Optional[SortDict]) -> List[Optional[ModelDict]]: + raise NotImplementedError diff --git a/labml_db/driver/mongo.py b/labml_db/driver/mongo.py index 1777764..69419c6 100644 --- a/labml_db/driver/mongo.py +++ b/labml_db/driver/mongo.py @@ -1,12 +1,12 @@ -from typing import List, Type, TYPE_CHECKING, Optional, Dict +from typing import List, Type, TYPE_CHECKING, Optional, Dict, Tuple -from labml_db.serializer.utils import encode_keys, decode_keys +import pymongo +from labml_db.serializer.utils import encode_keys, decode_keys from . import DbDriver -from ..types import ModelDict +from ..types import ModelDict, QueryDict, SortDict if TYPE_CHECKING: - import pymongo from ..model import Model @@ -70,3 +70,15 @@ def get_all(self) -> List[str]: cur = self._collection.find(projection=['_id']) keys = [self._to_key(d['_id']) for d in cur] return keys + + def get_by_dict(self, query: Optional[QueryDict], sort: Optional[SortDict]) -> List[Tuple[str, ModelDict]]: + query = {k: v for k, v in query.items()} if query else dict() + cursor = self._collection.find(query) + if sort is not None and len(sort) > 0: + sort_query = [(k, pymongo.ASCENDING if v else pymongo.DESCENDING) for k, v in sort.items()] + cursor.sort(sort_query) + res = [] + for d in cursor: + res.append((d['_id'], self._load_data(d))) + + return res diff --git a/labml_db/model.py b/labml_db/model.py index 08a1760..51b60a5 100644 --- a/labml_db/model.py +++ b/labml_db/model.py @@ -3,7 +3,7 @@ from typing import Generic, Union, Any from typing import TypeVar, List, Dict, Type, Set, Optional, _GenericAlias, TYPE_CHECKING -from .types import Primitive, ModelDict +from .types import Primitive, ModelDict, QueryDict, SortDict if TYPE_CHECKING: from .driver import DbDriver @@ -340,3 +340,9 @@ def __repr__(self): kv = [f'{k}={repr(v)}' for k, v in self._values.items()] kv = ', '.join(kv) return f'{self.__class__.__name__}({kv})' + + @classmethod + def get_by(cls, query: Optional[QueryDict] = None, sort: Optional[SortDict] = None): + db_driver = Model.__db_drivers[cls.__name__] + data = db_driver.get_by_dict(query=query, sort=sort) + return [Model._to_model(k, d) for k, d in data] diff --git a/labml_db/types.py b/labml_db/types.py index 0ac16d8..269b6fd 100644 --- a/labml_db/types.py +++ b/labml_db/types.py @@ -2,4 +2,6 @@ Primitive = Union[Dict[str, 'Primitive'], List['Primitive'], int, str, float, bool, None] ModelDict = Dict[str, Primitive] +QueryDict = Dict[str, Union[int, str, float, bool]] +SortDict = Dict[str, bool] From e1d65763440cd1f51731a8ecc93496fd69329aa5 Mon Sep 17 00:00:00 2001 From: Adithya Narasinghe Date: Wed, 7 Dec 2022 12:12:56 +0530 Subject: [PATCH 02/10] Add advanced querying, text searching, randomizing and limiting --- labml_db/driver/__init__.py | 6 +++-- labml_db/driver/mongo.py | 54 +++++++++++++++++++++++++++++-------- labml_db/model.py | 24 +++++++++++++---- labml_db/types.py | 5 ++-- 4 files changed, 68 insertions(+), 21 deletions(-) diff --git a/labml_db/driver/__init__.py b/labml_db/driver/__init__.py index 69ce3a6..9bc2626 100644 --- a/labml_db/driver/__init__.py +++ b/labml_db/driver/__init__.py @@ -1,4 +1,4 @@ -from typing import List, Type, TYPE_CHECKING, Optional +from typing import List, Type, TYPE_CHECKING, Optional, Tuple from ..types import ModelDict, QueryDict, SortDict @@ -29,5 +29,7 @@ def msave_dict(self, key: List[str], data: List[ModelDict]): def get_all(self) -> List[str]: raise NotImplementedError - def get_by_dict(self, query: Optional[QueryDict], sort: Optional[SortDict]) -> List[Optional[ModelDict]]: + def search(self, text_query: Optional[str], filters: Optional[QueryDict], sort: Optional[SortDict], + randomize: bool = False, limit: Optional[int] = None, sort_by_text_score: bool = False) -> Tuple[ + List[Tuple[str, ModelDict]], int]: raise NotImplementedError diff --git a/labml_db/driver/mongo.py b/labml_db/driver/mongo.py index 69419c6..cb4d7be 100644 --- a/labml_db/driver/mongo.py +++ b/labml_db/driver/mongo.py @@ -1,4 +1,5 @@ -from typing import List, Type, TYPE_CHECKING, Optional, Dict, Tuple +from collections import OrderedDict +from typing import List, Type, TYPE_CHECKING, Optional, Dict, Tuple, OrderedDict import pymongo @@ -71,14 +72,45 @@ def get_all(self) -> List[str]: keys = [self._to_key(d['_id']) for d in cur] return keys - def get_by_dict(self, query: Optional[QueryDict], sort: Optional[SortDict]) -> List[Tuple[str, ModelDict]]: - query = {k: v for k, v in query.items()} if query else dict() - cursor = self._collection.find(query) - if sort is not None and len(sort) > 0: - sort_query = [(k, pymongo.ASCENDING if v else pymongo.DESCENDING) for k, v in sort.items()] - cursor.sort(sort_query) + def search(self, text_query: Optional[str], filters: Optional[QueryDict], sort: Optional[SortDict], + randomize: bool = False, limit: Optional[int] = None, sort_by_text_score: bool = False) -> Tuple[ + List[Tuple[str, ModelDict]], int]: + pipeline = [] + + match = {k: v for k, v in filters.items()} if filters else dict() + if text_query: + match['$text'] = {'$search': text_query} + if len(match) > 0: + pipeline.append({'$match': match}) + + if randomize: + pipeline.append({'$facet': {'data': [{'$sample': {'size': limit}}], 'count': [{'$count': 'count'}]}}) + else: + sort_query = OrderedDict() + if sort_by_text_score: + sort_query['score'] = {'$meta': 'textScore'} + if sort is not None and len(sort) > 0: + for k, v in sort: + sort_query[k] = pymongo.ASCENDING if v else pymongo.DESCENDING + + pipeline.append({'$sort': sort_query}) + + if limit: + pipeline.append({'$facet': {'data': [{'$limit': limit}], 'count': [{'$count': 'count'}]}}) + + cursor = self._collection.aggregate(pipeline) res = [] - for d in cursor: - res.append((d['_id'], self._load_data(d))) - - return res + count = 0 + if limit: + for item in cursor: + for c in item['count']: + count += c['count'] + for d in item['data']: + res.append((d['_id'], self._load_data(d))) + else: + for d in cursor: + res.append((d['_id'], self._load_data(d))) + + count = len(res) + + return res, count diff --git a/labml_db/model.py b/labml_db/model.py index 51b60a5..47e4d14 100644 --- a/labml_db/model.py +++ b/labml_db/model.py @@ -1,6 +1,6 @@ import copy import warnings -from typing import Generic, Union, Any +from typing import Generic, Union, Any, Tuple from typing import TypeVar, List, Dict, Type, Set, Optional, _GenericAlias, TYPE_CHECKING from .types import Primitive, ModelDict, QueryDict, SortDict @@ -308,8 +308,8 @@ def from_dict_transform(cls, data: ModelDict) -> Dict[str, Any]: def to_dict(self) -> ModelDict: values = {} for k, v in self._values.items(): - if k not in self._defaults or self._defaults[k] != v: - values[k] = v + # TODO: exclude defaults from the saved data based on a flag + values[k] = v values = self.to_dict_transform(values) return values @@ -342,7 +342,21 @@ def __repr__(self): return f'{self.__class__.__name__}({kv})' @classmethod - def get_by(cls, query: Optional[QueryDict] = None, sort: Optional[SortDict] = None): + def search(cls, text_query: Optional[str] = None, filters: Optional[QueryDict] = None, + sort: Optional[SortDict] = None, randomize: bool = False, limit: Optional[int] = None, + sort_by_text_score: bool = False) -> Tuple[List[_KT], int]: + if sort is not None and len(sort) > 0 and randomize: + raise ValueError('Cannot have both randomize and sort criteria') + if limit is None or limit <= 0: + raise ValueError('Limit should be higher than 0') + if randomize and not limit: + raise ValueError('A limit should be provided when results are randomized') + if sort_by_text_score and not text_query: + raise ValueError("Cannot search by text score when there's no text query") + if randomize and sort_by_text_score: + raise ValueError('Cannot have both randomize and sort by text score') + db_driver = Model.__db_drivers[cls.__name__] - data = db_driver.get_by_dict(query=query, sort=sort) + data, total_count = db_driver.search(text_query=text_query, filters=filters, sort=sort, randomize=randomize, + limit=limit, sort_by_text_score=sort_by_text_score) return [Model._to_model(k, d) for k, d in data] diff --git a/labml_db/types.py b/labml_db/types.py index 269b6fd..c1d992a 100644 --- a/labml_db/types.py +++ b/labml_db/types.py @@ -1,7 +1,6 @@ -from typing import List, Dict, Union +from typing import List, Dict, Union, Tuple Primitive = Union[Dict[str, 'Primitive'], List['Primitive'], int, str, float, bool, None] ModelDict = Dict[str, Primitive] QueryDict = Dict[str, Union[int, str, float, bool]] -SortDict = Dict[str, bool] - +SortDict = List[Tuple[str, bool]] From 4c12ae612044d85745b826195ba9e48ee9c17d0a Mon Sep 17 00:00:00 2001 From: Adithya Narasinghe Date: Wed, 7 Dec 2022 14:18:01 +0530 Subject: [PATCH 03/10] bugfix --- labml_db/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/labml_db/model.py b/labml_db/model.py index 47e4d14..912148d 100644 --- a/labml_db/model.py +++ b/labml_db/model.py @@ -359,4 +359,4 @@ def search(cls, text_query: Optional[str] = None, filters: Optional[QueryDict] = db_driver = Model.__db_drivers[cls.__name__] data, total_count = db_driver.search(text_query=text_query, filters=filters, sort=sort, randomize=randomize, limit=limit, sort_by_text_score=sort_by_text_score) - return [Model._to_model(k, d) for k, d in data] + return [Model._to_model(k, d) for k, d in data], total_count From fd18cde5db0240280db2b4fe233b7f7e43edddf8 Mon Sep 17 00:00:00 2001 From: Adithya Narasinghe Date: Thu, 8 Dec 2022 03:45:36 +0530 Subject: [PATCH 04/10] Support not equal queries --- labml_db/driver/mongo.py | 9 ++++++++- labml_db/model.py | 2 +- labml_db/types.py | 3 ++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/labml_db/driver/mongo.py b/labml_db/driver/mongo.py index cb4d7be..49ea6dc 100644 --- a/labml_db/driver/mongo.py +++ b/labml_db/driver/mongo.py @@ -77,7 +77,14 @@ def search(self, text_query: Optional[str], filters: Optional[QueryDict], sort: List[Tuple[str, ModelDict]], int]: pipeline = [] - match = {k: v for k, v in filters.items()} if filters else dict() + match = dict() + if filters: + for property_name, item in filters.items(): + value, equal = item + if equal: + match[property_name] = value + else: + match[property_name] = {'$ne': value} if text_query: match['$text'] = {'$search': text_query} if len(match) > 0: diff --git a/labml_db/model.py b/labml_db/model.py index 912148d..93f78f3 100644 --- a/labml_db/model.py +++ b/labml_db/model.py @@ -347,7 +347,7 @@ def search(cls, text_query: Optional[str] = None, filters: Optional[QueryDict] = sort_by_text_score: bool = False) -> Tuple[List[_KT], int]: if sort is not None and len(sort) > 0 and randomize: raise ValueError('Cannot have both randomize and sort criteria') - if limit is None or limit <= 0: + if limit is not None and limit <= 0: raise ValueError('Limit should be higher than 0') if randomize and not limit: raise ValueError('A limit should be provided when results are randomized') diff --git a/labml_db/types.py b/labml_db/types.py index c1d992a..0dc1ffb 100644 --- a/labml_db/types.py +++ b/labml_db/types.py @@ -2,5 +2,6 @@ Primitive = Union[Dict[str, 'Primitive'], List['Primitive'], int, str, float, bool, None] ModelDict = Dict[str, Primitive] -QueryDict = Dict[str, Union[int, str, float, bool]] +# {Property: (value, equal/not_equal)} +QueryDict = Dict[str, Tuple[Union[List['Primitive'], int, str, float, bool], bool]] SortDict = List[Tuple[str, bool]] From cdae70f18b5c3f4486da0f53c6cbd72bf4af9153 Mon Sep 17 00:00:00 2001 From: Nipun Wijerathne Date: Sat, 10 Dec 2022 08:19:54 +0530 Subject: [PATCH 05/10] Update mongo.py --- labml_db/driver/mongo.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/labml_db/driver/mongo.py b/labml_db/driver/mongo.py index 49ea6dc..7f71c60 100644 --- a/labml_db/driver/mongo.py +++ b/labml_db/driver/mongo.py @@ -99,8 +99,9 @@ def search(self, text_query: Optional[str], filters: Optional[QueryDict], sort: if sort is not None and len(sort) > 0: for k, v in sort: sort_query[k] = pymongo.ASCENDING if v else pymongo.DESCENDING - - pipeline.append({'$sort': sort_query}) + + if len(sort_query) > 0: + pipeline.append({'$sort': sort_query}) if limit: pipeline.append({'$facet': {'data': [{'$limit': limit}], 'count': [{'$count': 'count'}]}}) From cdd840f426942deda37239d0ed23213eba629571 Mon Sep 17 00:00:00 2001 From: Adithya Narasinghe Date: Tue, 13 Dec 2022 17:38:22 +0530 Subject: [PATCH 06/10] Make sure defaults are stored during model instance creation --- labml_db/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/labml_db/model.py b/labml_db/model.py index 93f78f3..8f3346a 100644 --- a/labml_db/model.py +++ b/labml_db/model.py @@ -192,6 +192,10 @@ def __init__(self, key: Optional[str] = None, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) + for k, v in self._defaults: + if k not in kwargs: + setattr(self, k, v) + def __init_subclass__(cls, **kwargs): if cls.__name__ in Model.__models: warnings.warn(f"{cls.__name__} already used") From 60c035814f3db0a4ef91cea7bae8089943901455 Mon Sep 17 00:00:00 2001 From: Adithya Narasinghe Date: Tue, 13 Dec 2022 17:40:50 +0530 Subject: [PATCH 07/10] bugfix --- labml_db/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/labml_db/model.py b/labml_db/model.py index 8f3346a..ab08383 100644 --- a/labml_db/model.py +++ b/labml_db/model.py @@ -192,7 +192,7 @@ def __init__(self, key: Optional[str] = None, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) - for k, v in self._defaults: + for k, v in self._defaults.items(): if k not in kwargs: setattr(self, k, v) From 6bb4b89e2200c64ab4961ec52866c3a9f4dfc676 Mon Sep 17 00:00:00 2001 From: hnipun Date: Sat, 26 Aug 2023 16:40:36 +0530 Subject: [PATCH 08/10] check for empty before writing to the file --- labml_db/serializer/json.py | 1 + labml_db/serializer/yaml.py | 1 + 2 files changed, 2 insertions(+) diff --git a/labml_db/serializer/json.py b/labml_db/serializer/json.py index 13ff059..20d69a8 100644 --- a/labml_db/serializer/json.py +++ b/labml_db/serializer/json.py @@ -10,6 +10,7 @@ class JsonSerializer(Serializer): file_extension = 'json' def to_string(self, data: ModelDict) -> str: + assert data return json.dumps(encode_keys(data)) def from_string(self, data: Optional[str]) -> Optional[ModelDict]: diff --git a/labml_db/serializer/yaml.py b/labml_db/serializer/yaml.py index cb7a610..c58d2bd 100644 --- a/labml_db/serializer/yaml.py +++ b/labml_db/serializer/yaml.py @@ -10,6 +10,7 @@ class YamlSerializer(Serializer): def to_string(self, data: ModelDict) -> str: import yaml + assert data return yaml.dump(encode_keys(data), default_flow_style=False) def from_string(self, data: Optional[str]) -> Optional[ModelDict]: From 7cf6b2d80d69c15fede260800ab19502e109e202 Mon Sep 17 00:00:00 2001 From: hnipun Date: Tue, 14 Nov 2023 13:53:14 +0530 Subject: [PATCH 09/10] check for empty before writing to the file --- labml_db/driver/mongo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/labml_db/driver/mongo.py b/labml_db/driver/mongo.py index 7f71c60..39b3803 100644 --- a/labml_db/driver/mongo.py +++ b/labml_db/driver/mongo.py @@ -3,7 +3,7 @@ import pymongo -from labml_db.serializer.utils import encode_keys, decode_keys +from ..serializer.utils import encode_keys, decode_keys from . import DbDriver from ..types import ModelDict, QueryDict, SortDict From 36f8edb83cefc3dacdeb07fb1ec16d9c0153a08c Mon Sep 17 00:00:00 2001 From: hnipun Date: Tue, 14 Nov 2023 13:56:18 +0530 Subject: [PATCH 10/10] check for empty before writing to the file --- labml_db/serializer/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/labml_db/serializer/utils.py b/labml_db/serializer/utils.py index d6b230b..8d52809 100644 --- a/labml_db/serializer/utils.py +++ b/labml_db/serializer/utils.py @@ -1,7 +1,7 @@ from typing import Dict -from labml_db import Key -from labml_db.types import Primitive +from .. import Key +from ..types import Primitive def encode_key(key: Key) -> Dict[str, str]: