Skip to content

refactor pgstacDb class #375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/pypgstac/examples/load_queryables_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@


def load_for_specific_collections(
cli, sample_file, collection_ids, delete_missing=False,
cli,
sample_file,
collection_ids,
delete_missing=False,
):
"""Load queryables for specific collections.

Expand All @@ -27,7 +30,9 @@ def load_for_specific_collections(
delete_missing: If True, delete properties not present in the file
"""
cli.load_queryables(
str(sample_file), collection_ids=collection_ids, delete_missing=delete_missing,
str(sample_file),
collection_ids=collection_ids,
delete_missing=delete_missing,
)


Expand Down Expand Up @@ -57,7 +62,10 @@ def main():
# Example of loading for specific collections with delete_missing=True
# This will delete properties not present in the file, but only for the specified collections
load_for_specific_collections(
cli, sample_file, ["landsat-8", "sentinel-2"], delete_missing=True,
cli,
sample_file,
["landsat-8", "sentinel-2"],
delete_missing=True,
)


Expand Down
3 changes: 1 addition & 2 deletions src/pypgstac/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,9 @@ select = [
"PLE",
# "PLR",
"PLW",
"COM", # flake8-commas
]
ignore = [
# "E501", # line too long, handled by black
"E501", # line too long, handled by black
"B008", # do not perform function calls in argument defaults
"C901", # too complex
"B905",
Expand Down
1 change: 1 addition & 0 deletions src/pypgstac/src/pypgstac/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""pyPgSTAC Version."""

from pypgstac.version import __version__

__all__ = ["__version__"]
191 changes: 80 additions & 111 deletions src/pypgstac/src/pypgstac/db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Base library for database interaction with PgSTAC."""
import atexit

import contextlib
import logging
import time
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any, Generator, List, Optional, Tuple, Type, Union

Expand Down Expand Up @@ -52,37 +54,24 @@ class Settings(BaseSettings):
settings = Settings()


@dataclass
class PgstacDB:
"""Base class for interacting with PgSTAC Database."""

def __init__(
self,
dsn: Optional[str] = "",
pool: Optional[ConnectionPool] = None,
connection: Optional[Connection] = None,
commit_on_exit: bool = True,
debug: bool = False,
use_queue: bool = False,
) -> None:
"""Initialize Database."""
self.dsn: str
if dsn is not None:
self.dsn = dsn
else:
self.dsn = ""
self.pool = pool
self.connection = connection
self.commit_on_exit = commit_on_exit
self.initial_version = "0.1.9"
self.debug = debug
self.use_queue = use_queue
if self.debug:
logging.basicConfig(level=logging.DEBUG)
dsn: str
commit_on_exit: bool = True
debug: bool = False
use_queue: bool = False
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed possibility to pass a Connection object which adds to much complexity


def get_pool(self) -> ConnectionPool:
"""Get Database Pool."""
if self.pool is None:
self.pool = ConnectionPool(
pool: ConnectionPool = field(default=None)

initial_version: str = field(init=False, default="0.1.9")

_pool: ConnectionPool = field(init=False)

def __post_init__(self):
if not self.pool:
self._pool = ConnectionPool(
conninfo=self.dsn,
min_size=settings.db_min_conn_size,
max_size=settings.db_max_conn_size,
Expand All @@ -91,36 +80,51 @@ def get_pool(self) -> ConnectionPool:
num_workers=settings.db_num_workers,
open=True,
)
return self.pool

def open(self) -> None:
"""Open database pool connection."""
self.get_pool()
def get_pool(self) -> ConnectionPool:
"""Get Database Pool."""
return self.pool or self._pool

def close(self) -> None:
"""Close database pool connection."""
if self.pool is not None:
self.pool.close()
if self._pool is not None:
self._pool.close()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only close pool if we created it


def __enter__(self) -> Any:
"""Enter used for context."""
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Exit used for context."""
self.close()

@contextlib.contextmanager
def connect(self) -> Connection:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

connect() should be use in a context manager

with PgstacDb(dsn=...) as db:
    with db.connect() as conn:
         ...

"""Return database connection."""
pool = self.get_pool()
if self.connection is None:
self.connection = pool.getconn()
self.connection.autocommit = True

conn: Connection = None
try:
conn = pool.getconn()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could also switch to

with pool.connection() as conn:

conn.autocommit = True
if self.debug:
self.connection.add_notice_handler(pg_notice_handler)
self.connection.execute(
conn.add_notice_handler(pg_notice_handler)
conn.execute(
"SET CLIENT_MIN_MESSAGES TO NOTICE;",
prepare=False,
)
if self.use_queue:
self.connection.execute(
conn.execute(
"SET pgstac.use_queue TO TRUE;",
prepare=False,
)
atexit.register(self.disconnect)
self.connection.execute(

conn.execute(
"""
SELECT
CASE
Expand All @@ -138,54 +142,25 @@ def connect(self) -> Connection:
""",
prepare=False,
)
return self.connection
with conn:
yield conn

finally:
if conn:
pool.putconn(conn)

def wait(self) -> None:
"""Block until database connection is ready."""
cnt: int = 0
while cnt < 60:
try:
self.connect()
self.query("SELECT 1;")
return None
except psycopg.errors.OperationalError:
time.sleep(1)
cnt += 1
raise psycopg.errors.CannotConnectNow

def disconnect(self) -> None:
"""Disconnect from database."""
try:
if self.connection is not None:
if self.commit_on_exit:
self.connection.commit()
else:
self.connection.rollback()
except Exception:
pass
try:
if self.pool is not None and self.connection is not None:
self.pool.putconn(self.connection)
except Exception:
pass

self.connection = None
self.pool = None

def __enter__(self) -> Any:
"""Enter used for context."""
self.connect()
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Exit used for context."""
self.disconnect()

@retry(
stop=stop_after_attempt(settings.db_retries),
retry=retry_if_exception_type(psycopg.errors.OperationalError),
Expand All @@ -198,30 +173,27 @@ def query(
row_factory: psycopg.rows.BaseRowFactory = psycopg.rows.tuple_row,
) -> Generator:
"""Query the database with parameters."""
conn = self.connect()
try:
with conn.cursor(row_factory=row_factory) as cursor:
if args is None:
rows = cursor.execute(query, prepare=False)
else:
rows = cursor.execute(query, args)
if rows:
for row in rows:
yield row
else:
yield None
except psycopg.errors.OperationalError as e:
# If we get an operational error check the pool and retry
logger.warning(f"OPERATIONAL ERROR: {e}")
if self.pool is None:
self.get_pool()
else:
self.pool.check()
raise e
except psycopg.errors.DatabaseError as e:
if conn is not None:
conn.rollback()
raise e
with self.connect() as conn:
try:
with conn.cursor(row_factory=row_factory) as cursor:
if args is None:
rows = cursor.execute(query, prepare=False)
else:
rows = cursor.execute(query, args)
if rows:
for row in rows:
yield row
else:
yield None
except psycopg.errors.OperationalError as e:
# If we get an operational error check the pool and retry
logger.warning(f"OPERATIONAL ERROR: {e}")
self._pool.check()
raise e
except psycopg.errors.DatabaseError as e:
if conn is not None:
conn.rollback()
raise e

def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]:
"""Return results from a query that returns a single row."""
Expand All @@ -238,10 +210,9 @@ def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]:

def run_queued(self) -> str:
try:
self.connect().execute("""
CALL run_queued_queries();
""")
return "Ran Queued Queries"
with self.connect() as conn:
conn.execute("CALL run_queued_queries();")
return "Ran Queued Queries"
except Exception as e:
return f"Error Running Queued Queries: {e}"

Expand All @@ -262,8 +233,6 @@ def version(self) -> Optional[str]:
return version
except psycopg.errors.UndefinedTable:
logger.debug("PgSTAC is not installed.")
if self.connection is not None:
self.connection.rollback()
return None

@property
Expand All @@ -280,13 +249,13 @@ def pg_version(self) -> str:
if isinstance(version, str):
if int(version) < 130000:
major, minor, patch = tuple(
map(int, [version[i:i + 2] for i in range(0, len(version), 2)]),
map(int, [version[i : i + 2] for i in range(0, len(version), 2)]),
)
raise Exception(f"PgSTAC requires PostgreSQL 13+, current version is: {major}.{minor}.{patch}") # noqa: E501
raise Exception(
f"PgSTAC requires PostgreSQL 13+, current version is: {major}.{minor}.{patch}",
) # noqa: E501
return version
else:
if self.connection is not None:
self.connection.rollback()
raise Exception("Could not find PG version.")

def func(self, function_name: str, *args: Any) -> Generator:
Expand Down
Loading
Loading