diff --git a/src/pypgstac/examples/load_queryables_example.py b/src/pypgstac/examples/load_queryables_example.py
index 61bd8644..c242492a 100644
--- a/src/pypgstac/examples/load_queryables_example.py
+++ b/src/pypgstac/examples/load_queryables_example.py
@@ -16,7 +16,10 @@ def load_for_specific_collections(
-    cli, sample_file, collection_ids, delete_missing=False,
+    cli,
+    sample_file,
+    collection_ids,
+    delete_missing=False,
 ):
     """Load queryables for specific collections.
@@ -27,7 +30,9 @@ def load_for_specific_collections(
         delete_missing: If True, delete properties not present in the file
     """
     cli.load_queryables(
-        str(sample_file), collection_ids=collection_ids, delete_missing=delete_missing,
+        str(sample_file),
+        collection_ids=collection_ids,
+        delete_missing=delete_missing,
     )
@@ -57,7 +62,10 @@ def main():
     # Example of loading for specific collections with delete_missing=True
     # This will delete properties not present in the file, but only for the specified collections
     load_for_specific_collections(
-        cli, sample_file, ["landsat-8", "sentinel-2"], delete_missing=True,
+        cli,
+        sample_file,
+        ["landsat-8", "sentinel-2"],
+        delete_missing=True,
     )
diff --git a/src/pypgstac/pyproject.toml b/src/pypgstac/pyproject.toml
index d6322aa0..32b841a9 100644
--- a/src/pypgstac/pyproject.toml
+++ b/src/pypgstac/pyproject.toml
@@ -99,10 +99,9 @@ select = [
     "PLE",
     # "PLR",
     "PLW",
-    "COM", # flake8-commas
 ]
 ignore = [
-    # "E501", # line too long, handled by black
+    "E501", # line too long, handled by black
     "B008", # do not perform function calls in argument defaults
     "C901", # too complex
     "B905",
diff --git a/src/pypgstac/src/pypgstac/__init__.py b/src/pypgstac/src/pypgstac/__init__.py
index 9886933d..4f4e0438 100644
--- a/src/pypgstac/src/pypgstac/__init__.py
+++ b/src/pypgstac/src/pypgstac/__init__.py
@@ -1,4 +1,5 @@
 """pyPgSTAC Version."""
+
 from pypgstac.version import __version__
 
 __all__ = ["__version__"]
diff --git a/src/pypgstac/src/pypgstac/db.py b/src/pypgstac/src/pypgstac/db.py
index 001ec55b..3fead502 100644
--- a/src/pypgstac/src/pypgstac/db.py
+++ b/src/pypgstac/src/pypgstac/db.py
@@ -1,7 +1,9 @@
 """Base library for database interaction with PgSTAC."""
-import atexit
+
+import contextlib
 import logging
 import time
+from dataclasses import dataclass, field
 from types import TracebackType
 from typing import Any, Generator, List, Optional, Tuple, Type, Union
@@ -52,37 +54,24 @@ class Settings(BaseSettings):
 
 settings = Settings()
 
 
+@dataclass
 class PgstacDB:
     """Base class for interacting with PgSTAC Database."""
 
-    def __init__(
-        self,
-        dsn: Optional[str] = "",
-        pool: Optional[ConnectionPool] = None,
-        connection: Optional[Connection] = None,
-        commit_on_exit: bool = True,
-        debug: bool = False,
-        use_queue: bool = False,
-    ) -> None:
-        """Initialize Database."""
-        self.dsn: str
-        if dsn is not None:
-            self.dsn = dsn
-        else:
-            self.dsn = ""
-        self.pool = pool
-        self.connection = connection
-        self.commit_on_exit = commit_on_exit
-        self.initial_version = "0.1.9"
-        self.debug = debug
-        self.use_queue = use_queue
-        if self.debug:
-            logging.basicConfig(level=logging.DEBUG)
+    dsn: str
+    commit_on_exit: bool = True
+    debug: bool = False
+    use_queue: bool = False
 
-    def get_pool(self) -> ConnectionPool:
-        """Get Database Pool."""
-        if self.pool is None:
-            self.pool = ConnectionPool(
+    pool: Optional[ConnectionPool] = None
+
+    initial_version: str = field(init=False, default="0.1.9")
+
+    _pool: Optional[ConnectionPool] = field(init=False, default=None)
+
+    def __post_init__(self) -> None:
+        if not self.pool:
+            self._pool = ConnectionPool(
                 conninfo=self.dsn,
                 min_size=settings.db_min_conn_size,
                 max_size=settings.db_max_conn_size,
@@ -91,36 +80,51 @@ def get_pool(self) -> ConnectionPool:
                 num_workers=settings.db_num_workers,
                 open=True,
             )
-        return self.pool
 
-    def open(self) -> None:
-        """Open database pool connection."""
-        self.get_pool()
+    def get_pool(self) -> ConnectionPool:
+        """Get Database Pool."""
+        return self.pool or self._pool
 
     def close(self) -> None:
         """Close database pool connection."""
-        if self.pool is not None:
-            self.pool.close()
+        if self._pool is not None:
+            self._pool.close()
+
+    def __enter__(self) -> Any:
+        """Enter used for context."""
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        """Exit used for context."""
+        self.close()
 
+    @contextlib.contextmanager
     def connect(self) -> Connection:
         """Return database connection."""
         pool = self.get_pool()
-        if self.connection is None:
-            self.connection = pool.getconn()
-            self.connection.autocommit = True
+
+        conn: Optional[Connection] = None
+        try:
+            conn = pool.getconn()
+            conn.autocommit = True
             if self.debug:
-                self.connection.add_notice_handler(pg_notice_handler)
-                self.connection.execute(
+                conn.add_notice_handler(pg_notice_handler)
+                conn.execute(
                     "SET CLIENT_MIN_MESSAGES TO NOTICE;",
                     prepare=False,
                 )
             if self.use_queue:
-                self.connection.execute(
+                conn.execute(
                     "SET pgstac.use_queue TO TRUE;",
                     prepare=False,
                 )
-            atexit.register(self.disconnect)
-            self.connection.execute(
+
+            conn.execute(
                 """
                 SELECT
                     CASE
@@ -138,14 +142,18 @@ def connect(self) -> Connection:
                 """,
                 prepare=False,
             )
-            return self.connection
+            with conn:
+                yield conn
+
+        finally:
+            if conn:
+                pool.putconn(conn)
 
     def wait(self) -> None:
         """Block until database connection is ready."""
         cnt: int = 0
         while cnt < 60:
             try:
-                self.connect()
-                self.query("SELECT 1;")
+                self.query_one("SELECT 1;")
                 return None
             except psycopg.errors.OperationalError:
@@ -153,39 +161,6 @@
                 cnt += 1
         raise psycopg.errors.CannotConnectNow
 
-    def disconnect(self) -> None:
-        """Disconnect from database."""
-        try:
-            if self.connection is not None:
-                if self.commit_on_exit:
-                    self.connection.commit()
-                else:
-                    self.connection.rollback()
-        except Exception:
-            pass
-        try:
-            if self.pool is not None and self.connection is not None:
-                self.pool.putconn(self.connection)
-        except Exception:
-            pass
-
-        self.connection = None
-        self.pool = None
-
-    def __enter__(self) -> Any:
-        """Enter used for context."""
-        self.connect()
-        return self
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc: Optional[BaseException],
-        traceback: Optional[TracebackType],
-    ) -> None:
-        """Exit used for context."""
-        self.disconnect()
-
     @retry(
         stop=stop_after_attempt(settings.db_retries),
         retry=retry_if_exception_type(psycopg.errors.OperationalError),
@@ -198,30 +173,27 @@ def query(
         self,
         query: str,
         args: Optional[List] = None,
         row_factory: psycopg.rows.BaseRowFactory = psycopg.rows.tuple_row,
     ) -> Generator:
         """Query the database with parameters."""
-        conn = self.connect()
-        try:
-            with conn.cursor(row_factory=row_factory) as cursor:
-                if args is None:
-                    rows = cursor.execute(query, prepare=False)
-                else:
-                    rows = cursor.execute(query, args)
-                if rows:
-                    for row in rows:
-                        yield row
-                else:
-                    yield None
-        except psycopg.errors.OperationalError as e:
-            # If we get an operational error check the pool and retry
-            logger.warning(f"OPERATIONAL ERROR: {e}")
-            if self.pool is None:
-                self.get_pool()
-            else:
-                self.pool.check()
-            raise e
-        except psycopg.errors.DatabaseError as e:
-            if conn is not None:
-                conn.rollback()
-            raise e
+        with self.connect() as conn:
+            try:
+                with conn.cursor(row_factory=row_factory) as cursor:
+                    if args is None:
+                        rows = cursor.execute(query, prepare=False)
+                    else:
+                        rows = cursor.execute(query, args)
+                    if rows:
+                        for row in rows:
+                            yield row
+                    else:
+                        yield None
+            except psycopg.errors.OperationalError as e:
+                # If we get an operational error check the pool and retry
+                logger.warning(f"OPERATIONAL ERROR: {e}")
+                self.get_pool().check()
+                raise e
+            except psycopg.errors.DatabaseError as e:
+                if conn is not None:
+                    conn.rollback()
+                raise e
 
     def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]:
         """Return results from a query that returns a single row."""
@@ -238,10 +210,9 @@ def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]:
 
     def run_queued(self) -> str:
         try:
-            self.connect().execute("""
-                CALL run_queued_queries();
-            """)
-            return "Ran Queued Queries"
+            with self.connect() as conn:
+                conn.execute("CALL run_queued_queries();")
+                return "Ran Queued Queries"
         except Exception as e:
             return f"Error Running Queued Queries: {e}"
 
@@ -262,8 +233,6 @@ def version(self) -> Optional[str]:
             return version
         except psycopg.errors.UndefinedTable:
             logger.debug("PgSTAC is not installed.")
-            if self.connection is not None:
-                self.connection.rollback()
             return None
 
     @property
@@ -280,13 +249,13 @@ def pg_version(self) -> str:
         if isinstance(version, str):
             if int(version) < 130000:
                 major, minor, patch = tuple(
-                    map(int, [version[i:i + 2] for i in range(0, len(version), 2)]),
+                    map(int, [version[i : i + 2] for i in range(0, len(version), 2)]),
                 )
-                raise Exception(f"PgSTAC requires PostgreSQL 13+, current version is: {major}.{minor}.{patch}")  # noqa: E501
+                raise Exception(
+                    f"PgSTAC requires PostgreSQL 13+, current version is: {major}.{minor}.{patch}",
+                )
             return version
         else:
-            if self.connection is not None:
-                self.connection.rollback()
             raise Exception("Could not find PG version.")
 
     def func(self, function_name: str, *args: Any) -> Generator:
diff --git a/src/pypgstac/src/pypgstac/load.py b/src/pypgstac/src/pypgstac/load.py
index 3d64f7f7..14cc66eb 100644
--- a/src/pypgstac/src/pypgstac/load.py
+++ b/src/pypgstac/src/pypgstac/load.py
@@ -5,7 +5,7 @@
 import logging
 import sys
 import time
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
@@ -148,15 +148,13 @@ def read_json(file: Union[Path, str, Iterator[Any]] = "stdin") -> Iterable:
             yield orjson.loads(line)
 
 
+@dataclass
 class Loader:
     """Utilities for loading data."""
 
     db: PgstacDB
-    _partition_cache: Dict[str, Partition]
 
-    def __init__(self, db: PgstacDB):
-        self.db = db
-        self._partition_cache: Dict[str, Partition] = {}
+    _partition_cache: Dict[str, Partition] = field(init=False, default_factory=dict)
 
     def check_version(self) -> None:
         db_version = self.db.version
@@ -205,60 +203,61 @@ def load_collections(
         if file is None:
             file = "stdin"
 
-        conn = self.db.connect()
-        with conn.cursor() as cur:
-            with conn.transaction():
-                cur.execute(
-                    """
-                    DROP TABLE IF EXISTS tmp_collections;
-                    CREATE TEMP TABLE tmp_collections
-                    (content jsonb) ON COMMIT DROP;
-                    """,
-                )
-                with cur.copy("COPY tmp_collections (content) FROM stdin;") as copy:
-                    for collection in read_json(file):
-                        copy.write_row((orjson.dumps(collection).decode(),))
-                if insert_mode in (
-                    None,
-                    Methods.insert,
-                ):
-                    cur.execute(
-                        """
-                        INSERT INTO collections (content)
-                        SELECT content FROM tmp_collections;
-                        """,
-                    )
-                    logger.debug(cur.statusmessage)
-                    logger.debug(f"Rows affected: {cur.rowcount}")
-                elif insert_mode in (
-                    Methods.insert_ignore,
-                    Methods.ignore,
-                ):
-                    cur.execute(
-                        """
-                        INSERT INTO collections (content)
-                        SELECT content FROM tmp_collections
-                        ON CONFLICT DO NOTHING;
-                        """,
-                    )
-                    logger.debug(cur.statusmessage)
-                    logger.debug(f"Rows affected: {cur.rowcount}")
-                elif insert_mode == Methods.upsert:
-                    cur.execute(
-                        """
-                        INSERT INTO collections (content)
-                        SELECT content FROM tmp_collections
-                        ON CONFLICT (id) DO
-                        UPDATE SET content=EXCLUDED.content;
-                        """,
-                    )
-                    logger.debug(cur.statusmessage)
-                    logger.debug(f"Rows affected: {cur.rowcount}")
-                else:
-                    raise Exception(
-                        "Available modes are insert, ignore, and upsert."
-                        f"You entered {insert_mode}.",
-                    )
+
+        with self.db.connect() as conn:
+            with conn.cursor() as cur:
+                with conn.transaction():
+                    cur.execute(
+                        """
+                        DROP TABLE IF EXISTS tmp_collections;
+                        CREATE TEMP TABLE tmp_collections
+                        (content jsonb) ON COMMIT DROP;
+                        """,
+                    )
+                    with cur.copy("COPY tmp_collections (content) FROM stdin;") as copy:
+                        for collection in read_json(file):
+                            copy.write_row((orjson.dumps(collection).decode(),))
+                    if insert_mode in (
+                        None,
+                        Methods.insert,
+                    ):
+                        cur.execute(
+                            """
+                            INSERT INTO collections (content)
+                            SELECT content FROM tmp_collections;
+                            """,
+                        )
+                        logger.debug(cur.statusmessage)
+                        logger.debug(f"Rows affected: {cur.rowcount}")
+                    elif insert_mode in (
+                        Methods.insert_ignore,
+                        Methods.ignore,
+                    ):
+                        cur.execute(
+                            """
+                            INSERT INTO collections (content)
+                            SELECT content FROM tmp_collections
+                            ON CONFLICT DO NOTHING;
+                            """,
+                        )
+                        logger.debug(cur.statusmessage)
+                        logger.debug(f"Rows affected: {cur.rowcount}")
+                    elif insert_mode == Methods.upsert:
+                        cur.execute(
+                            """
+                            INSERT INTO collections (content)
+                            SELECT content FROM tmp_collections
+                            ON CONFLICT (id) DO
+                            UPDATE SET content=EXCLUDED.content;
+                            """,
+                        )
+                        logger.debug(cur.statusmessage)
+                        logger.debug(f"Rows affected: {cur.rowcount}")
+                    else:
+                        raise Exception(
+                            "Available modes are insert, ignore, and upsert. "
+ f"You entered {insert_mode}.", + ) @retry( stop=stop_after_attempt(5), @@ -276,186 +275,192 @@ def load_partition( insert_mode: Optional[Methods] = Methods.insert, ) -> None: """Load items data for a single partition.""" - conn = self.db.connect() - t = time.perf_counter() + with self.db.connect() as conn: + t = time.perf_counter() - logger.debug(f"Loading data for partition: {partition}.") - with conn.cursor() as cur: - if partition.requires_update: - with conn.transaction(): - cur.execute( - """ - SELECT check_partition( - %s, - tstzrange(%s, %s, '[]'), - tstzrange(%s, %s, '[]') - ); - """, - ( - partition.collection, - partition.datetime_range_min, - partition.datetime_range_max, - partition.end_datetime_range_min, - partition.end_datetime_range_max, - ), - ) - - logger.debug( - f"Adding or updating partition {partition.name} " - f"took {time.perf_counter() - t}s", - ) - partition.requires_update = False - else: - logger.debug(f"Partition {partition.name} does not require an update.") - - with conn.transaction(): - t = time.perf_counter() - if insert_mode in ( - None, - Methods.insert, - ): - with cur.copy( - sql.SQL( + logger.debug(f"Loading data for partition: {partition}.") + with conn.cursor() as cur: + if partition.requires_update: + with conn.transaction(): + cur.execute( """ - COPY {} - (id, collection, datetime, - end_datetime, geometry, - content, private) - FROM stdin; - """, - ).format(sql.Identifier(partition.name)), - ) as copy: - for item in items: - item.pop("partition") - copy.write_row( - ( - item["id"], - item["collection"], - item["datetime"], - item["end_datetime"], - item["geometry"], - item["content"], - item.get("private", None), - ), - ) - logger.debug(cur.statusmessage) - logger.debug(f"Rows affected: {cur.rowcount}") - elif insert_mode in ( - Methods.insert_ignore, - Methods.upsert, - Methods.delsert, - Methods.ignore, - ): - cur.execute( - """ - DROP TABLE IF EXISTS items_ingest_temp; - CREATE TEMP TABLE items_ingest_temp - ON COMMIT DROP AS SELECT * FROM items LIMIT 0; + SELECT check_partition( + %s, + tstzrange(%s, %s, '[]'), + tstzrange(%s, %s, '[]') + ); """, - ) - with cur.copy( - """ - COPY items_ingest_temp - (id, collection, datetime, - end_datetime, geometry, - content, private) - FROM stdin; - """, - ) as copy: - for item in items: - item.pop("partition") - copy.write_row( - ( - item["id"], - item["collection"], - item["datetime"], - item["end_datetime"], - item["geometry"], - item["content"], - item.get("private", None), - ), - ) - logger.debug(cur.statusmessage) - logger.debug(f"Copied rows: {cur.rowcount}") + ( + partition.collection, + partition.datetime_range_min, + partition.datetime_range_max, + partition.end_datetime_range_min, + partition.end_datetime_range_max, + ), + ) - cur.execute( - sql.SQL( - """ - LOCK TABLE ONLY {} IN EXCLUSIVE MODE; - """, - ).format(sql.Identifier(partition.name)), + logger.debug( + f"Adding or updating partition {partition.name} " + f"took {time.perf_counter() - t}s", + ) + partition.requires_update = False + else: + logger.debug( + f"Partition {partition.name} does not require an update.", ) + + with conn.transaction(): + t = time.perf_counter() if insert_mode in ( - Methods.ignore, - Methods.insert_ignore, + None, + Methods.insert, ): - cur.execute( + with cur.copy( sql.SQL( """ - INSERT INTO {} - SELECT * - FROM items_ingest_temp ON CONFLICT DO NOTHING; + COPY {} + (id, collection, datetime, + end_datetime, geometry, + content, private) + FROM stdin; """, ).format(sql.Identifier(partition.name)), - ) + ) as 
copy: + for item in items: + item.pop("partition") + copy.write_row( + ( + item["id"], + item["collection"], + item["datetime"], + item["end_datetime"], + item["geometry"], + item["content"], + item.get("private", None), + ), + ) logger.debug(cur.statusmessage) logger.debug(f"Rows affected: {cur.rowcount}") - elif insert_mode == Methods.upsert: + elif insert_mode in ( + Methods.insert_ignore, + Methods.upsert, + Methods.delsert, + Methods.ignore, + ): cur.execute( - sql.SQL( - """ - INSERT INTO {} AS t SELECT * FROM items_ingest_temp - ON CONFLICT (id) DO UPDATE - SET - datetime = EXCLUDED.datetime, - end_datetime = EXCLUDED.end_datetime, - geometry = EXCLUDED.geometry, - collection = EXCLUDED.collection, - content = EXCLUDED.content - WHERE t IS DISTINCT FROM EXCLUDED - ; + """ + DROP TABLE IF EXISTS items_ingest_temp; + CREATE TEMP TABLE items_ingest_temp + ON COMMIT DROP AS SELECT * FROM items LIMIT 0; """, - ).format(sql.Identifier(partition.name)), ) + with cur.copy( + """ + COPY items_ingest_temp + (id, collection, datetime, + end_datetime, geometry, + content, private) + FROM stdin; + """, + ) as copy: + for item in items: + item.pop("partition") + copy.write_row( + ( + item["id"], + item["collection"], + item["datetime"], + item["end_datetime"], + item["geometry"], + item["content"], + item.get("private", None), + ), + ) logger.debug(cur.statusmessage) - logger.debug(f"Rows affected: {cur.rowcount}") - elif insert_mode == Methods.delsert: + logger.debug(f"Copied rows: {cur.rowcount}") + cur.execute( sql.SQL( """ - WITH deletes AS ( - DELETE FROM items i USING items_ingest_temp s - WHERE - i.id = s.id - AND i.collection = s.collection - ) - INSERT INTO {} AS t SELECT * FROM items_ingest_temp - ON CONFLICT (id) DO UPDATE - SET - datetime = EXCLUDED.datetime, - end_datetime = EXCLUDED.end_datetime, - geometry = EXCLUDED.geometry, - collection = EXCLUDED.collection, - content = EXCLUDED.content - WHERE t IS DISTINCT FROM EXCLUDED - ; + LOCK TABLE ONLY {} IN EXCLUSIVE MODE; """, ).format(sql.Identifier(partition.name)), ) - logger.debug(cur.statusmessage) - logger.debug(f"Rows affected: {cur.rowcount}") - else: - raise Exception( - "Available modes are insert, ignore, upsert, and delsert." 
- f"You entered {insert_mode}.", + if insert_mode in ( + Methods.ignore, + Methods.insert_ignore, + ): + cur.execute( + sql.SQL( + """ + INSERT INTO {} + SELECT * + FROM items_ingest_temp ON CONFLICT DO NOTHING; + """, + ).format(sql.Identifier(partition.name)), + ) + logger.debug(cur.statusmessage) + logger.debug(f"Rows affected: {cur.rowcount}") + elif insert_mode == Methods.upsert: + cur.execute( + sql.SQL( + """ + INSERT INTO {} AS t SELECT * FROM items_ingest_temp + ON CONFLICT (id) DO UPDATE + SET + datetime = EXCLUDED.datetime, + end_datetime = EXCLUDED.end_datetime, + geometry = EXCLUDED.geometry, + collection = EXCLUDED.collection, + content = EXCLUDED.content + WHERE t IS DISTINCT FROM EXCLUDED + ; + """, + ).format(sql.Identifier(partition.name)), + ) + logger.debug(cur.statusmessage) + logger.debug(f"Rows affected: {cur.rowcount}") + elif insert_mode == Methods.delsert: + cur.execute( + sql.SQL( + """ + WITH deletes AS ( + DELETE FROM items i USING items_ingest_temp s + WHERE + i.id = s.id + AND i.collection = s.collection + ) + INSERT INTO {} AS t SELECT * FROM items_ingest_temp + ON CONFLICT (id) DO UPDATE + SET + datetime = EXCLUDED.datetime, + end_datetime = EXCLUDED.end_datetime, + geometry = EXCLUDED.geometry, + collection = EXCLUDED.collection, + content = EXCLUDED.content + WHERE t IS DISTINCT FROM EXCLUDED + ; + """, + ).format(sql.Identifier(partition.name)), + ) + logger.debug(cur.statusmessage) + logger.debug(f"Rows affected: {cur.rowcount}") + else: + raise Exception( + "Available modes are insert, ignore, upsert, and delsert." + f"You entered {insert_mode}.", + ) + logger.debug("Updating Partition Stats") + cur.execute( + "SELECT update_partition_stats_q(%s);", + (partition.name,), ) - logger.debug("Updating Partition Stats") - cur.execute("SELECT update_partition_stats_q(%s);", (partition.name,)) - logger.debug(cur.statusmessage) - logger.debug(f"Rows affected: {cur.rowcount}") - logger.debug( - f"Copying data for {partition} took {time.perf_counter() - t} seconds", - ) + logger.debug(cur.statusmessage) + logger.debug(f"Rows affected: {cur.rowcount}") + + logger.debug( + f"Copying data for {partition} took {time.perf_counter() - t} seconds", + ) def _partition_update(self, item: Dict[str, Any]) -> str: """Update the cached partition with the item information and return the name. 
@@ -604,6 +609,7 @@ def load_items(
         if file is None:
             file = "stdin"
         t = time.perf_counter()
+        self._partition_cache = {}
 
         if dehydrated and isinstance(file, str):
diff --git a/src/pypgstac/src/pypgstac/migrate.py b/src/pypgstac/src/pypgstac/migrate.py
index ea4c1244..e8bfe3f5 100644
--- a/src/pypgstac/src/pypgstac/migrate.py
+++ b/src/pypgstac/src/pypgstac/migrate.py
@@ -1,9 +1,11 @@
 """Utilities to help migrate pgstac schema."""
+
 import glob
 import logging
 import os
 import re
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import Any, Dict, Iterator, List, Optional
 
 from smart_open import open
@@ -36,7 +38,8 @@ def __init__(self, path: str, f: str, t: str) -> None:
     def parse_filename(self, filename: str) -> List[str]:
         """Get version numbers from filename."""
         filename = os.path.splitext(os.path.basename(filename))[0].replace(
-            "pgstac.", "",
+            "pgstac.",
+            "",
         )
         return filename.split("-")
 
@@ -96,7 +99,7 @@ def migrations(self) -> List[str]:
 def get_sql(file: str) -> str:
     """Get sql from a file as a string."""
     sqlstrs = []
-    file = re.sub("[0-9]+[.][0-9]+[.][0-9]+-dev","unreleased",file)
+    file = re.sub("[0-9]+[.][0-9]+[.][0-9]+-dev", "unreleased", file)
 
     fp = os.path.join(migrations_dir, file)
     file_handle: Any = open(fp)
@@ -105,20 +108,19 @@ def get_sql(file: str) -> str:
     return "\n".join(sqlstrs)
 
 
+@dataclass
 class Migrate:
     """Utilities for migrating pgstac database."""
 
-    def __init__(self, db: PgstacDB, schema: str = "pgstac"):
-        """Prepare for migration."""
-        self.db = db
-        self.schema = schema
+    db: PgstacDB
+    schema: str = "pgstac"
 
     def run_migration(self, toversion: Optional[str] = None) -> str:
         """Migrate a pgstac database to current version."""
         if toversion is None:
             toversion = __version__
         files = []
-        if re.search(r"-dev$",toversion):
+        if re.search(r"-dev$", toversion):
             logger.info("using unreleased version")
             toversion = "unreleased"
 
@@ -126,7 +128,7 @@ def run_migration(self, toversion: Optional[str] = None) -> str:
             map(
                 int,
                 [
-                    self.db.pg_version[i:i + 2]
+                    self.db.pg_version[i : i + 2]
                     for i in range(0, len(self.db.pg_version), 2)
                 ],
             ),
@@ -147,18 +149,17 @@ def run_migration(self, toversion: Optional[str] = None) -> str:
         if len(files) < 1:
             raise Exception("Could not find migration files")
 
-        conn = self.db.connect()
-
-        with conn.cursor() as cur:
-            conn.autocommit = False
-            for file in files:
-                logger.debug(f"Running migration file {file}.")
-                migration_sql = get_sql(file)
-                cur.execute(migration_sql)
-                logger.debug(cur.statusmessage)
-                logger.debug(cur.rowcount)
-
-        logger.debug(f"Database migrated to {toversion}")
+        with self.db.connect() as conn:
+            with conn.cursor() as cur:
+                conn.autocommit = False
+                for file in files:
+                    logger.debug(f"Running migration file {file}.")
+                    migration_sql = get_sql(file)
+                    cur.execute(migration_sql)
+                    logger.debug(cur.statusmessage)
+                    logger.debug(cur.rowcount)
+
+            logger.debug(f"Database migrated to {toversion}")
 
         newversion = self.db.version
         if conn is not None:
diff --git a/src/pypgstac/src/pypgstac/pypgstac.py b/src/pypgstac/src/pypgstac/pypgstac.py
index e0720850..d4357fbc 100644
--- a/src/pypgstac/src/pypgstac/pypgstac.py
+++ b/src/pypgstac/src/pypgstac/pypgstac.py
@@ -6,6 +6,7 @@
 
 import fire
 import orjson
+from psycopg.rows import tuple_row
 from smart_open import open
 
 from pypgstac.db import PgstacDB
@@ -28,7 +29,9 @@ def __init__(
             sys.exit(0)
 
         self.dsn = dsn
-        self._db = PgstacDB(dsn=dsn, debug=debug, use_queue=usequeue)
+        self.debug = debug
+        self.usequeue = usequeue
+
         if debug:
             logging.basicConfig(level=logging.DEBUG)
             sys.tracebacklimit = 1000
@@ -41,25 +44,29 @@ def initversion(self) -> str:
     @property
     def version(self) -> Optional[str]:
         """Get PgSTAC version installed on database."""
-        return self._db.version
+        with PgstacDB(dsn=self.dsn, debug=self.debug, use_queue=self.usequeue) as db:
+            return db.version
 
     @property
     def pg_version(self) -> str:
         """Get PostgreSQL server version installed on database."""
-        return self._db.pg_version
+        with PgstacDB(dsn=self.dsn, debug=self.debug, use_queue=self.usequeue) as db:
+            return db.pg_version
 
     def pgready(self) -> None:
         """Wait for a pgstac database to accept connections."""
-        self._db.wait()
+        with PgstacDB(dsn=self.dsn, debug=self.debug, use_queue=self.usequeue) as db:
+            db.wait()
 
     def search(self, query: str) -> str:
         """Search PgSTAC."""
-        return self._db.search(query)
+        with PgstacDB(dsn=self.dsn, debug=self.debug, use_queue=self.usequeue) as db:
+            return db.search(query)
 
     def migrate(self, toversion: Optional[str] = None) -> str:
         """Migrate PgSTAC Database."""
-        migrator = Migrate(self._db)
-        return migrator.run_migration(toversion=toversion)
+        with PgstacDB(dsn=self.dsn, debug=self.debug, use_queue=self.usequeue) as db:
+            return Migrate(db).run_migration(toversion=toversion)
 
     def load(
         self,
@@ -70,55 +77,58 @@ def load(
         chunksize: Optional[int] = 10000,
     ) -> None:
         """Load collections or items into PgSTAC."""
-        loader = Loader(db=self._db)
-        if table == "collections":
-            loader.load_collections(file, method)
-        if table == "items":
-            loader.load_items(file, method, dehydrated, chunksize)
+        with PgstacDB(dsn=self.dsn, debug=self.debug, use_queue=self.usequeue) as db:
+            loader = Loader(db=db)
+            if table == "collections":
+                loader.load_collections(file, method)
+            if table == "items":
+                loader.load_items(file, method, dehydrated, chunksize)
 
     def runqueue(self) -> str:
-        return self._db.run_queued()
+        with PgstacDB(dsn=self.dsn, debug=self.debug, use_queue=self.usequeue) as db:
+            return db.run_queued()
 
     def loadextensions(self) -> None:
-        conn = self._db.connect()
-
-        with conn.cursor() as cur:
-            cur.execute(
-                """
-                INSERT INTO stac_extensions (url)
-                SELECT DISTINCT
-                substring(
-                    jsonb_array_elements_text(content->'stac_extensions') FROM E'^[^#]*'
-                )
-                FROM collections
-                ON CONFLICT DO NOTHING;
-                """,
-            )
-            conn.commit()
-
-        urls = self._db.query(
-            """
-            SELECT url FROM stac_extensions WHERE content IS NULL;
-            """,
-        )
-        if urls:
-            for u in urls:
-                url = u[0]
-                try:
-                    with open(url, "r") as f:
-                        content = f.read()
-                    self._db.query(
-                        """
-                        UPDATE pgstac.stac_extensions
-                        SET content=%s
-                        WHERE url=%s
-                        ;
-                        """,
-                        [content, url],
-                    )
-                    conn.commit()
-                except Exception:
-                    pass
+        with PgstacDB(dsn=self.dsn, debug=self.debug, use_queue=self.usequeue) as db:
+            with db.connect() as conn:
+                with conn.cursor() as cur:
+                    cur.execute(
+                        """
+                        INSERT INTO stac_extensions (url)
+                        SELECT DISTINCT
+                        substring(
+                            jsonb_array_elements_text(content->'stac_extensions') FROM E'^[^#]*'
+                        )
+                        FROM collections
+                        ON CONFLICT DO NOTHING;
+                        """,
+                    )
+                    conn.commit()
+
+                with conn.cursor(row_factory=tuple_row) as cur:
+                    query = """
+                        SELECT url FROM stac_extensions WHERE content IS NULL;
+                    """
+                    urls = cur.execute(query, prepare=False).fetchall()
+                    for u in urls:
+                        url = u[0]
+                        try:
+                            with open(url, "r") as f:
+                                content = f.read()
+
+                            cur.execute(
+                                """
+                                UPDATE pgstac.stac_extensions
+                                SET content=%s
+                                WHERE url=%s
+                                ;
+                                """,
+                                [content, url],
+                            )
+
+                            conn.commit()
+                        except Exception:
+                            pass
 
     def load_queryables(
         self,
@@ -155,146 +165,149 @@ def load_queryables(
         if not properties:
             raise ValueError("No properties found in queryables definition")
 
-        conn = self._db.connect()
-        with conn.cursor() as cur:
-            with conn.transaction():
-                # Insert each property as a queryable
-                for name, definition in properties.items():
-                    # Skip core fields that are already indexed
-                    if name in (
-                        "id",
-                        "geometry",
-                        "datetime",
-                        "end_datetime",
-                        "collection",
-                    ):
-                        continue
-
-                    # Determine property wrapper based on type
-                    property_wrapper = "to_text"  # default
-                    if definition.get("type") == "number":
-                        property_wrapper = "to_float"
-                    elif definition.get("type") == "integer":
-                        property_wrapper = "to_int"
-                    elif definition.get("format") == "date-time":
-                        property_wrapper = "to_tstz"
-                    elif definition.get("type") == "array":
-                        property_wrapper = "to_text_array"
-
-                    # Determine if this field should be indexed
-                    property_index_type = None
-                    if index_fields and name in index_fields:
-                        property_index_type = "BTREE"
-
-                    # First delete any existing queryable with the same name
-                    if not collection_ids:
-                        # If no collection_ids specified, delete queryables
-                        # with NULL collection_ids
-                        cur.execute(
-                            """
-                            DELETE FROM queryables
-                            WHERE name = %s AND collection_ids IS NULL
-                            """,
-                            [name],
-                        )
-                    else:
-                        # Delete queryables with matching name and collection_ids
-                        cur.execute(
-                            """
-                            DELETE FROM queryables
-                            WHERE name = %s AND collection_ids = %s::text[]
-                            """,
-                            [name, collection_ids],
-                        )
-
-                        # Also delete queryables with NULL collection_ids
-                        cur.execute(
-                            """
-                            DELETE FROM queryables
-                            WHERE name = %s AND collection_ids IS NULL
-                            """,
-                            [name],
-                        )
-
-                    # Then insert the new queryable
-                    cur.execute(
-                        """
-                        INSERT INTO queryables
-                        (name, collection_ids, definition, property_wrapper,
-                        property_index_type)
-                        VALUES (%s, %s, %s, %s, %s)
-                        """,
-                        [
-                            name,
-                            collection_ids,
-                            orjson.dumps(definition).decode(),
-                            property_wrapper,
-                            property_index_type,
-                        ],
-                    )
-
-                # If delete_missing is True,
-                # delete all queryables that were not in the file
-                if delete_missing:
-                    # Get the list of property names from the file
-                    property_names = list(properties.keys())
-
-                    # Skip core fields that are already indexed
-                    core_fields = [
-                        "id",
-                        "geometry",
-                        "datetime",
-                        "end_datetime",
-                        "collection",
-                    ]
-                    property_names = [
-                        name for name in property_names if name not in core_fields
-                    ]
-
-                    if not property_names:
-                        # If no valid properties, don't delete anything
-                        pass
-                    elif not collection_ids:
-                        # If no collection_ids specified,
-                        # delete queryables with NULL collection_ids
-                        # that are not in the property_names list
-                        placeholders = ", ".join(["%s"] * len(property_names))
-                        core_placeholders = ", ".join(["%s"] * len(core_fields))
-
-                        # Build the query with proper placeholders
-                        query = f"""
-                        DELETE FROM queryables
-                        WHERE collection_ids IS NULL
-                        AND name NOT IN ({placeholders})
-                        AND name NOT IN ({core_placeholders})
-                        """
-
-                        # Flatten the parameters
-                        params = property_names + core_fields
-
-                        cur.execute(query, params)
-                    else:
-                        # Delete queryables with matching collection_ids
-                        # that are not in the property_names list
-                        placeholders = ", ".join(["%s"] * len(property_names))
-                        core_placeholders = ", ".join(["%s"] * len(core_fields))
-
-                        # Build the query with proper placeholders
-                        query = f"""
-                        DELETE FROM queryables
-                        WHERE collection_ids = %s::text[]
-                        AND name NOT IN ({placeholders})
-                        AND name NOT IN ({core_placeholders})
-                        """
-
-                        # Flatten the parameters
-                        params = [collection_ids] + property_names + core_fields
-
-                        cur.execute(query, params)
-
-                # Trigger index creation only if index_fields were provided
-                if index_fields and len(index_fields) > 0:
-                    cur.execute("SELECT maintain_partitions();")
+        with PgstacDB(dsn=self.dsn, debug=self.debug, use_queue=self.usequeue) as db:
+            with db.connect() as conn:
+                with conn.cursor() as cur:
+                    with conn.transaction():
+                        # Insert each property as a queryable
+                        for name, definition in properties.items():
+                            # Skip core fields that are already indexed
+                            if name in (
+                                "id",
+                                "geometry",
+                                "datetime",
+                                "end_datetime",
+                                "collection",
+                            ):
+                                continue
+
+                            # Determine property wrapper based on type
+                            property_wrapper = "to_text"  # default
+                            if definition.get("type") == "number":
+                                property_wrapper = "to_float"
+                            elif definition.get("type") == "integer":
+                                property_wrapper = "to_int"
+                            elif definition.get("format") == "date-time":
+                                property_wrapper = "to_tstz"
+                            elif definition.get("type") == "array":
+                                property_wrapper = "to_text_array"
+
+                            # Determine if this field should be indexed
+                            property_index_type = None
+                            if index_fields and name in index_fields:
+                                property_index_type = "BTREE"
+
+                            # First delete any existing queryable with the same name
+                            if not collection_ids:
+                                # If no collection_ids specified, delete queryables
+                                # with NULL collection_ids
+                                cur.execute(
+                                    """
+                                    DELETE FROM queryables
+                                    WHERE name = %s AND collection_ids IS NULL
+                                    """,
+                                    [name],
+                                )
+                            else:
+                                # Delete queryables with matching name and collection_ids
+                                cur.execute(
+                                    """
+                                    DELETE FROM queryables
+                                    WHERE name = %s AND collection_ids = %s::text[]
+                                    """,
+                                    [name, collection_ids],
+                                )
+
+                                # Also delete queryables with NULL collection_ids
+                                cur.execute(
+                                    """
+                                    DELETE FROM queryables
+                                    WHERE name = %s AND collection_ids IS NULL
+                                    """,
+                                    [name],
+                                )
+
+                            # Then insert the new queryable
+                            cur.execute(
+                                """
+                                INSERT INTO queryables
+                                (name, collection_ids, definition, property_wrapper,
+                                property_index_type)
+                                VALUES (%s, %s, %s, %s, %s)
+                                """,
+                                [
+                                    name,
+                                    collection_ids,
+                                    orjson.dumps(definition).decode(),
+                                    property_wrapper,
+                                    property_index_type,
+                                ],
+                            )
+
+                        # If delete_missing is True,
+                        # delete all queryables that were not in the file
+                        if delete_missing:
+                            # Get the list of property names from the file
+                            property_names = list(properties.keys())
+
+                            # Skip core fields that are already indexed
+                            core_fields = [
+                                "id",
+                                "geometry",
+                                "datetime",
+                                "end_datetime",
+                                "collection",
+                            ]
+                            property_names = [
+                                name
+                                for name in property_names
+                                if name not in core_fields
+                            ]
+
+                            if not property_names:
+                                # If no valid properties, don't delete anything
+                                pass
+                            elif not collection_ids:
+                                # If no collection_ids specified,
+                                # delete queryables with NULL collection_ids
+                                # that are not in the property_names list
+                                placeholders = ", ".join(["%s"] * len(property_names))
+                                core_placeholders = ", ".join(["%s"] * len(core_fields))
+
+                                # Build the query with proper placeholders
+                                query = f"""
+                                DELETE FROM queryables
+                                WHERE collection_ids IS NULL
+                                AND name NOT IN ({placeholders})
+                                AND name NOT IN ({core_placeholders})
+                                """
+
+                                # Flatten the parameters
+                                params = property_names + core_fields
+
+                                cur.execute(query, params)
+                            else:
+                                # Delete queryables with matching collection_ids
+                                # that are not in the property_names list
+                                placeholders = ", ".join(["%s"] * len(property_names))
+                                core_placeholders = ", ".join(["%s"] * len(core_fields))
+
+                                # Build the query with proper placeholders
+                                query = f"""
+                                DELETE FROM queryables
+                                WHERE collection_ids = %s::text[]
+                                AND name NOT IN ({placeholders})
+                                AND name NOT IN ({core_placeholders})
+                                """
+
+                                # Flatten the parameters
+                                params = [collection_ids] + property_names + core_fields
+
+                                cur.execute(query, params)
+
+                        # Trigger index creation only if index_fields were provided
+                        if index_fields and len(index_fields) > 0:
+                            cur.execute("SELECT maintain_partitions();")
 
 
 def cli() -> fire.Fire:
diff --git a/src/pypgstac/tests/conftest.py b/src/pypgstac/tests/conftest.py
index d7536799..3b1f9258 100644
--- a/src/pypgstac/tests/conftest.py
+++ b/src/pypgstac/tests/conftest.py
@@ -1,4 +1,5 @@
 """Fixtures for pypgstac tests."""
+
 import os
 from typing import Generator
@@ -40,11 +41,11 @@ def db() -> Generator:
     try:
         conn.execute("DROP DATABASE pypgstactestdb;")
         conn.execute(
-        """
+            """
             CREATE DATABASE pypgstactestdb
             TEMPLATE pgstac_test_db_template;
             """,
-        )
+            )
     except Exception:
         pass
diff --git a/src/pypgstac/tests/hydration/test_dehydrate_pg.py b/src/pypgstac/tests/hydration/test_dehydrate_pg.py
index 0734eead..8646c08b 100644
--- a/src/pypgstac/tests/hydration/test_dehydrate_pg.py
+++ b/src/pypgstac/tests/hydration/test_dehydrate_pg.py
@@ -36,7 +36,9 @@ def db(self) -> Generator:
         os.environ["PGDATABASE"] = origdb
 
     def dehydrate(
-        self, base_item: Dict[str, Any], item: Dict[str, Any],
+        self,
+        base_item: Dict[str, Any],
+        item: Dict[str, Any],
     ) -> Dict[str, Any]:
         """Dehydrate item using pgstac."""
         with self.db() as db:
diff --git a/src/pypgstac/tests/hydration/test_hydrate_pg.py b/src/pypgstac/tests/hydration/test_hydrate_pg.py
index 7f7ddc05..99e9e5f7 100644
--- a/src/pypgstac/tests/hydration/test_hydrate_pg.py
+++ b/src/pypgstac/tests/hydration/test_hydrate_pg.py
@@ -1,4 +1,5 @@
 """Test Hydration in PgSTAC."""
+
 import os
 from contextlib import contextmanager
 from typing import Any, Dict, Generator
@@ -37,7 +38,9 @@ def db(self) -> Generator[PgstacDB, None, None]:
         os.environ["PGDATABASE"] = origdb
 
     def hydrate(
-        self, base_item: Dict[str, Any], item: Dict[str, Any],
+        self,
+        base_item: Dict[str, Any],
+        item: Dict[str, Any],
     ) -> Dict[str, Any]:
         """Hydrate using pgstac."""
         with self.db() as db:
diff --git a/src/pypgstac/tests/test_queryables.py b/src/pypgstac/tests/test_queryables.py
index 31241818..9e1067c0 100644
--- a/src/pypgstac/tests/test_queryables.py
+++ b/src/pypgstac/tests/test_queryables.py
@@ -223,7 +223,8 @@ def test_maintain_partitions_called_only_with_index_fields(mock_connect):
 
     # Check that maintain_partitions was called
     maintain_calls = [
-        call_args for call_args in mock_cursor.execute.call_args_list
+        call_args
+        for call_args in mock_cursor.execute.call_args_list
         if "maintain_partitions" in str(call_args)
     ]
     assert len(maintain_calls) == 1
@@ -236,7 +237,8 @@ def test_maintain_partitions_called_only_with_index_fields(mock_connect):
 
     # Check that maintain_partitions was not called
     maintain_calls = [
-        call_args for call_args in mock_cursor.execute.call_args_list
+        call_args
+        for call_args in mock_cursor.execute.call_args_list
         if "maintain_partitions" in str(call_args)
     ]
     assert len(maintain_calls) == 0
@@ -425,7 +427,8 @@ def test_load_queryables_delete_missing(db: PgstacDB) -> None:
 
 
 def test_load_queryables_delete_missing_with_collections(
-    db: PgstacDB, loader: Loader,
+    db: PgstacDB,
+    loader: Loader,
 ) -> None:
     """Test loading queryables with delete_missing=True and specific collections."""
     # Load test collections first
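The collection-scoped `delete_missing` semantics exercised above can also be asserted directly against the `queryables` table. A minimal sketch, assuming the `db` fixture from conftest and the illustrative collection ids from the example file (the asserted property name is hypothetical):

    # PgstacDB.query yields result tuples, or a single None when the result
    # set is empty, so filter out the sentinel before unpacking.
    rows = db.query(
        "SELECT name FROM queryables WHERE collection_ids = %s::text[];",
        [["landsat-8", "sentinel-2"]],
    )
    names = {row[0] for row in rows if row is not None}
    assert "eo:cloud_cover" in names  # hypothetical property from the sample file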