-
Notifications
You must be signed in to change notification settings - Fork 46
refactor pgstacDb class #375
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
"""pyPgSTAC Version.""" | ||
|
||
from pypgstac.version import __version__ | ||
|
||
__all__ = ["__version__"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
"""Base library for database interaction with PgSTAC.""" | ||
import atexit | ||
|
||
import contextlib | ||
import logging | ||
import time | ||
from dataclasses import dataclass, field | ||
from types import TracebackType | ||
from typing import Any, Generator, List, Optional, Tuple, Type, Union | ||
|
||
|
@@ -52,37 +54,24 @@ class Settings(BaseSettings): | |
settings = Settings() | ||
|
||
|
||
@dataclass | ||
class PgstacDB: | ||
"""Base class for interacting with PgSTAC Database.""" | ||
|
||
def __init__( | ||
self, | ||
dsn: Optional[str] = "", | ||
pool: Optional[ConnectionPool] = None, | ||
connection: Optional[Connection] = None, | ||
commit_on_exit: bool = True, | ||
debug: bool = False, | ||
use_queue: bool = False, | ||
) -> None: | ||
"""Initialize Database.""" | ||
self.dsn: str | ||
if dsn is not None: | ||
self.dsn = dsn | ||
else: | ||
self.dsn = "" | ||
self.pool = pool | ||
self.connection = connection | ||
self.commit_on_exit = commit_on_exit | ||
self.initial_version = "0.1.9" | ||
self.debug = debug | ||
self.use_queue = use_queue | ||
if self.debug: | ||
logging.basicConfig(level=logging.DEBUG) | ||
dsn: str | ||
commit_on_exit: bool = True | ||
debug: bool = False | ||
use_queue: bool = False | ||
|
||
def get_pool(self) -> ConnectionPool: | ||
"""Get Database Pool.""" | ||
if self.pool is None: | ||
self.pool = ConnectionPool( | ||
pool: ConnectionPool = field(default=None) | ||
|
||
initial_version: str = field(init=False, default="0.1.9") | ||
|
||
_pool: ConnectionPool = field(init=False) | ||
|
||
def __post_init__(self): | ||
if not self.pool: | ||
self._pool = ConnectionPool( | ||
conninfo=self.dsn, | ||
min_size=settings.db_min_conn_size, | ||
max_size=settings.db_max_conn_size, | ||
|
@@ -91,36 +80,51 @@ def get_pool(self) -> ConnectionPool: | |
num_workers=settings.db_num_workers, | ||
open=True, | ||
) | ||
return self.pool | ||
|
||
def open(self) -> None: | ||
"""Open database pool connection.""" | ||
self.get_pool() | ||
def get_pool(self) -> ConnectionPool: | ||
"""Get Database Pool.""" | ||
return self.pool or self._pool | ||
|
||
def close(self) -> None: | ||
"""Close database pool connection.""" | ||
if self.pool is not None: | ||
self.pool.close() | ||
if self._pool is not None: | ||
self._pool.close() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. only close |
||
|
||
def __enter__(self) -> Any: | ||
"""Enter used for context.""" | ||
return self | ||
|
||
def __exit__( | ||
self, | ||
exc_type: Optional[Type[BaseException]], | ||
exc: Optional[BaseException], | ||
traceback: Optional[TracebackType], | ||
) -> None: | ||
"""Exit used for context.""" | ||
self.close() | ||
|
||
@contextlib.contextmanager | ||
def connect(self) -> Connection: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
with PgstacDb(dsn=...) as db:
with db.connect() as conn:
... |
||
"""Return database connection.""" | ||
pool = self.get_pool() | ||
if self.connection is None: | ||
self.connection = pool.getconn() | ||
self.connection.autocommit = True | ||
|
||
conn: Connection = None | ||
try: | ||
conn = pool.getconn() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we could also switch to with pool.connection() as conn: |
||
conn.autocommit = True | ||
if self.debug: | ||
self.connection.add_notice_handler(pg_notice_handler) | ||
self.connection.execute( | ||
conn.add_notice_handler(pg_notice_handler) | ||
conn.execute( | ||
"SET CLIENT_MIN_MESSAGES TO NOTICE;", | ||
prepare=False, | ||
) | ||
if self.use_queue: | ||
self.connection.execute( | ||
conn.execute( | ||
"SET pgstac.use_queue TO TRUE;", | ||
prepare=False, | ||
) | ||
atexit.register(self.disconnect) | ||
self.connection.execute( | ||
|
||
conn.execute( | ||
""" | ||
SELECT | ||
CASE | ||
|
@@ -138,54 +142,25 @@ def connect(self) -> Connection: | |
""", | ||
prepare=False, | ||
) | ||
return self.connection | ||
with conn: | ||
yield conn | ||
|
||
finally: | ||
if conn: | ||
pool.putconn(conn) | ||
|
||
def wait(self) -> None: | ||
"""Block until database connection is ready.""" | ||
cnt: int = 0 | ||
while cnt < 60: | ||
try: | ||
self.connect() | ||
self.query("SELECT 1;") | ||
return None | ||
except psycopg.errors.OperationalError: | ||
time.sleep(1) | ||
cnt += 1 | ||
raise psycopg.errors.CannotConnectNow | ||
|
||
def disconnect(self) -> None: | ||
"""Disconnect from database.""" | ||
try: | ||
if self.connection is not None: | ||
if self.commit_on_exit: | ||
self.connection.commit() | ||
else: | ||
self.connection.rollback() | ||
except Exception: | ||
pass | ||
try: | ||
if self.pool is not None and self.connection is not None: | ||
self.pool.putconn(self.connection) | ||
except Exception: | ||
pass | ||
|
||
self.connection = None | ||
self.pool = None | ||
|
||
def __enter__(self) -> Any: | ||
"""Enter used for context.""" | ||
self.connect() | ||
return self | ||
|
||
def __exit__( | ||
self, | ||
exc_type: Optional[Type[BaseException]], | ||
exc: Optional[BaseException], | ||
traceback: Optional[TracebackType], | ||
) -> None: | ||
"""Exit used for context.""" | ||
self.disconnect() | ||
|
||
@retry( | ||
stop=stop_after_attempt(settings.db_retries), | ||
retry=retry_if_exception_type(psycopg.errors.OperationalError), | ||
|
@@ -198,30 +173,27 @@ def query( | |
row_factory: psycopg.rows.BaseRowFactory = psycopg.rows.tuple_row, | ||
) -> Generator: | ||
"""Query the database with parameters.""" | ||
conn = self.connect() | ||
try: | ||
with conn.cursor(row_factory=row_factory) as cursor: | ||
if args is None: | ||
rows = cursor.execute(query, prepare=False) | ||
else: | ||
rows = cursor.execute(query, args) | ||
if rows: | ||
for row in rows: | ||
yield row | ||
else: | ||
yield None | ||
except psycopg.errors.OperationalError as e: | ||
# If we get an operational error check the pool and retry | ||
logger.warning(f"OPERATIONAL ERROR: {e}") | ||
if self.pool is None: | ||
self.get_pool() | ||
else: | ||
self.pool.check() | ||
raise e | ||
except psycopg.errors.DatabaseError as e: | ||
if conn is not None: | ||
conn.rollback() | ||
raise e | ||
with self.connect() as conn: | ||
try: | ||
with conn.cursor(row_factory=row_factory) as cursor: | ||
if args is None: | ||
rows = cursor.execute(query, prepare=False) | ||
else: | ||
rows = cursor.execute(query, args) | ||
if rows: | ||
for row in rows: | ||
yield row | ||
else: | ||
yield None | ||
except psycopg.errors.OperationalError as e: | ||
# If we get an operational error check the pool and retry | ||
logger.warning(f"OPERATIONAL ERROR: {e}") | ||
self._pool.check() | ||
raise e | ||
except psycopg.errors.DatabaseError as e: | ||
if conn is not None: | ||
conn.rollback() | ||
raise e | ||
|
||
def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]: | ||
"""Return results from a query that returns a single row.""" | ||
|
@@ -238,10 +210,9 @@ def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]: | |
|
||
def run_queued(self) -> str: | ||
try: | ||
self.connect().execute(""" | ||
CALL run_queued_queries(); | ||
""") | ||
return "Ran Queued Queries" | ||
with self.connect() as conn: | ||
conn.execute("CALL run_queued_queries();") | ||
return "Ran Queued Queries" | ||
except Exception as e: | ||
return f"Error Running Queued Queries: {e}" | ||
|
||
|
@@ -262,8 +233,6 @@ def version(self) -> Optional[str]: | |
return version | ||
except psycopg.errors.UndefinedTable: | ||
logger.debug("PgSTAC is not installed.") | ||
if self.connection is not None: | ||
self.connection.rollback() | ||
return None | ||
|
||
@property | ||
|
@@ -280,13 +249,13 @@ def pg_version(self) -> str: | |
if isinstance(version, str): | ||
if int(version) < 130000: | ||
major, minor, patch = tuple( | ||
map(int, [version[i:i + 2] for i in range(0, len(version), 2)]), | ||
map(int, [version[i : i + 2] for i in range(0, len(version), 2)]), | ||
) | ||
raise Exception(f"PgSTAC requires PostgreSQL 13+, current version is: {major}.{minor}.{patch}") # noqa: E501 | ||
raise Exception( | ||
f"PgSTAC requires PostgreSQL 13+, current version is: {major}.{minor}.{patch}", | ||
) # noqa: E501 | ||
return version | ||
else: | ||
if self.connection is not None: | ||
self.connection.rollback() | ||
raise Exception("Could not find PG version.") | ||
|
||
def func(self, function_name: str, *args: Any) -> Generator: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed possibility to pass a
Connection
object which adds to much complexity