diff --git a/docs/settings.rst b/docs/settings.rst index 5d8482104..7df6907cc 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -78,6 +78,26 @@ Default: automatically constructed by boto to account for region The URL endpoint for DynamoDB. This can be used to use a local implementation of DynamoDB such as DynamoDB Local or dynalite. +dax_write_endpoints +------------------ + +Default: ``[]`` + +Connect to DAX endpoints for write operations. + +Supported Operations: PutItem, DeleteItem, UpdateItem, BatchWriteItem + + +dax_read_endpoints +------------------ + +Default: ``[]`` + +Connect to DAX endpoints for read operations. + +Supported Operations: GetItem, Scan, BatchGetItem, Query + + Overriding settings ~~~~~~~~~~~~~~~~~~~ diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index cca050f7d..0412b30a7 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -48,6 +48,7 @@ VerboseClientError, TransactGetError, TransactWriteError) from pynamodb.expressions.condition import Condition +from pynamodb.connection.dax import DaxClient, OP_READ, OP_WRITE from pynamodb.expressions.operand import Path from pynamodb.expressions.projection import create_projection_expression from pynamodb.expressions.update import Action, Update @@ -247,11 +248,15 @@ def __init__(self, max_retry_attempts: Optional[int] = None, base_backoff_ms: Optional[int] = None, max_pool_connections: Optional[int] = None, - extra_headers: Optional[Mapping[str, str]] = None): + extra_headers: Optional[Mapping[str, str]] = None, + dax_write_endpoints: Optional[List[str]] = None, + dax_read_endpoints: Optional[List[str]] = None, + fallback_to_dynamodb: Optional[bool] = False): self._tables: Dict[str, MetaTable] = {} self.host = host self._local = local() self._client = None + if region: self.region = region else: @@ -287,6 +292,21 @@ def __init__(self, else: self._extra_headers = get_settings_value('extra_headers') + if dax_write_endpoints is None: + dax_write_endpoints = get_settings_value('dax_write_endpoints') + + if dax_read_endpoints is None: + dax_read_endpoints = get_settings_value('dax_read_endpoints') + + self._dax_support = bool(dax_write_endpoints or dax_read_endpoints) + self._dax_read_client = None if not dax_read_endpoints else DaxClient(endpoints=dax_read_endpoints, region_name=self.region) + self._dax_write_client = None if not dax_write_endpoints else DaxClient(endpoints=dax_write_endpoints, region_name=self.region) + + if fallback_to_dynamodb is not None: + self._fallback_to_dynamodb = fallback_to_dynamodb + else: + self._fallback_to_dynamodb = get_settings_value('fallback_to_dynamodb') + def __repr__(self) -> str: return "Connection<{}>".format(self.client.meta.endpoint_url) @@ -354,6 +374,18 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict, settings: 1. It's faster to avoid using botocore's response parsing 2. It provides a place to monkey patch HTTP requests for unit testing """ + if self._dax_support: + from amazondax.DaxError import DaxClientError + + try: + if operation_name in OP_WRITE and self._dax_write_client: + return self._dax_write_client.dispatch(operation_name, operation_kwargs) + elif operation_name in OP_READ and self._dax_read_client: + return self._dax_read_client.dispatch(operation_name, operation_kwargs) + except DaxClientError: + if not self._fallback_to_dynamodb: + raise + operation_model = self.client._service_model.operation_model(operation_name) request_dict = self.client._convert_to_request_dict( operation_kwargs, diff --git a/pynamodb/connection/dax.py b/pynamodb/connection/dax.py new file mode 100644 index 000000000..414c4db2a --- /dev/null +++ b/pynamodb/connection/dax.py @@ -0,0 +1,33 @@ +from typing import Dict, List + +OP_WRITE = { + 'PutItem': 'put_item', + 'DeleteItem': 'delete_item', + 'UpdateItem': 'update_item', + 'BatchWriteItem': 'batch_write_item', +} + +OP_READ = { + 'GetItem': 'get_item', + 'Scan': 'scan', + 'BatchGetItem': 'batch_get_item', + 'Query': 'query', +} + +OP_NAME_TO_METHOD = OP_WRITE.copy() +OP_NAME_TO_METHOD.update(OP_READ) + + +class DaxClient(object): + + def __init__(self, endpoints: List[str], region_name: str): + from amazondax import AmazonDaxClient + + self.connection = AmazonDaxClient( + endpoints=endpoints, + region_name=region_name + ) + + def dispatch(self, operation_name: str, operation_kwargs: Dict): + method = getattr(self.connection, OP_NAME_TO_METHOD[operation_name]) + return method(**operation_kwargs) diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 183467a9f..1a96d4f0b 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -3,7 +3,7 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ -from typing import Any, Dict, Mapping, Optional, Sequence +from typing import Any, Dict, List, Mapping, Optional, Sequence from pynamodb.connection.base import Connection, MetaTable, OperationSettings from pynamodb.constants import DEFAULT_BILLING_MODE, KEY @@ -30,6 +30,9 @@ def __init__( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_session_token: Optional[str] = None, + dax_write_endpoints: Optional[List[str]] = None, + dax_read_endpoints: Optional[List[str]] = None, + fallback_to_dynamodb: Optional[bool] = False ) -> None: self.table_name = table_name self.connection = Connection(region=region, @@ -39,7 +42,10 @@ def __init__( max_retry_attempts=max_retry_attempts, base_backoff_ms=base_backoff_ms, max_pool_connections=max_pool_connections, - extra_headers=extra_headers) + extra_headers=extra_headers, + dax_write_endpoints=dax_write_endpoints, + dax_read_endpoints=dax_read_endpoints, + fallback_to_dynamodb=fallback_to_dynamodb) if aws_access_key_id and aws_secret_access_key: self.connection.session.set_credentials(aws_access_key_id, diff --git a/pynamodb/models.py b/pynamodb/models.py index 7d5e99161..4355911ba 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -258,6 +258,12 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None: setattr(attr_obj, 'aws_secret_access_key', None) if not hasattr(attr_obj, 'aws_session_token'): setattr(attr_obj, 'aws_session_token', None) + if not hasattr(attr_obj, 'dax_write_endpoints'): + setattr(attr_obj, 'dax_write_endpoints', get_settings_value('dax_write_endpoints')) + if not hasattr(attr_obj, 'dax_read_endpoints'): + setattr(attr_obj, 'dax_read_endpoints', get_settings_value('dax_read_endpoints')) + if not hasattr(attr_obj, 'fallback_to_dynamodb'): + setattr(attr_obj, 'fallback_to_dynamodb', get_settings_value('fallback_to_dynamodb')) # create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist, # so that "except Model.DoesNotExist:" would not catch other models' exceptions @@ -1072,7 +1078,10 @@ def _get_connection(cls) -> TableConnection: extra_headers=cls.Meta.extra_headers, aws_access_key_id=cls.Meta.aws_access_key_id, aws_secret_access_key=cls.Meta.aws_secret_access_key, - aws_session_token=cls.Meta.aws_session_token) + aws_session_token=cls.Meta.aws_session_token, + dax_write_endpoints=cls.Meta.dax_write_endpoints, + dax_read_endpoints=cls.Meta.dax_read_endpoints, + fallback_to_dynamodb=cls.Meta.fallback_to_dynamodb) return cls._connection @classmethod diff --git a/pynamodb/settings.py b/pynamodb/settings.py index 7283dce03..78af4b347 100644 --- a/pynamodb/settings.py +++ b/pynamodb/settings.py @@ -16,6 +16,9 @@ 'region': None, 'max_pool_connections': 10, 'extra_headers': None, + 'dax_write_endpoints': [], + 'dax_read_endpoints': [], + 'fallback_to_dynamodb': False } OVERRIDE_SETTINGS_PATH = getenv('PYNAMODB_CONFIG', '/etc/pynamodb/global_default_settings.py') diff --git a/requirements-dev.txt b/requirements-dev.txt index 1b2b06323..377763dbc 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ +amazon-dax-client>=2.0.0,<3.0.0 pytest>=6 pytest-env pytest-mock diff --git a/setup.py b/setup.py index 825879c13..c8b55df76 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ ], extras_require={ 'signals': ['blinker>=1.3,<2.0'], + 'dax': ['amazon-dax-client>=2.0.0,<3.0.0'] }, package_data={'pynamodb': ['py.typed']}, )