Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

### Added

- Added support for enum queryables [#390](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/390)

### Changed

- Optimize data_loader.py script [#395](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/395)
6 changes: 2 additions & 4 deletions Makefile
@@ -3,13 +3,11 @@ APP_HOST ?= 0.0.0.0
EXTERNAL_APP_PORT ?= 8080

ES_APP_PORT ?= 8080
OS_APP_PORT ?= 8082

ES_HOST ?= docker.for.mac.localhost
ES_PORT ?= 9200

OS_APP_PORT ?= 8082
OS_HOST ?= docker.for.mac.localhost
OS_PORT ?= 9202

run_es = docker compose \
run \
-p ${EXTERNAL_APP_PORT}:${ES_APP_PORT} \
14 changes: 13 additions & 1 deletion stac_fastapi/core/stac_fastapi/core/base_database_logic.py
@@ -1,7 +1,7 @@
"""Base database logic."""

import abc
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional


class BaseDatabaseLogic(abc.ABC):
@@ -36,6 +36,18 @@ async def delete_item(
"""Delete an item from the database."""
pass

@abc.abstractmethod
async def get_items_mapping(self, collection_id: str) -> Dict[str, Dict[str, Any]]:
"""Get the mapping for the items in the collection."""
pass

@abc.abstractmethod
async def get_items_unique_values(
self, collection_id: str, field_names: Iterable[str], *, limit: int = ...
) -> Dict[str, List[str]]:
"""Get the unique values for the given fields in the collection."""
pass

@abc.abstractmethod
async def create_collection(self, collection: Dict, refresh: bool = False) -> None:
"""Create a collection in the database."""
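For orientation, a minimal usage sketch of the new contract (the collection id and field name below are hypothetical; the concrete Elasticsearch and OpenSearch implementations appear further down):

    # Hedged usage sketch: "database" is any concrete BaseDatabaseLogic
    # implementation; returns distinct values per requested field, and fields
    # whose cardinality exceeds the limit are omitted from the result.
    async def example(database: "BaseDatabaseLogic") -> None:
        unique_values = await database.get_items_unique_values(
            "sentinel-2-l2a",          # hypothetical collection id
            ["properties.platform"],   # fully qualified field names in the items index
            limit=100,
        )
        # e.g. {"properties.platform": ["sentinel-2a", "sentinel-2b"]}
        print(unique_values)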
11 changes: 11 additions & 0 deletions stac_fastapi/core/stac_fastapi/core/extensions/filter.py
@@ -60,6 +60,17 @@
"maximum": 100,
},
}
"""Queryables that are present in all collections."""

OPTIONAL_QUERYABLES: Dict[str, Dict[str, Any]] = {
"platform": {
"$enum": True,
"description": "Satellite platform identifier",
},
}
"""Queryables that are present in some collections."""

ALL_QUERYABLES: Dict[str, Dict[str, Any]] = DEFAULT_QUERYABLES | OPTIONAL_QUERYABLES


class LogicalOp(str, Enum):
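For reference (not part of the diff), a hypothetical additional entry would be declared the same way; `$enum` is only a sentinel that the filter client pops and replaces with the concrete values found in the index:

    from typing import Any, Dict

    # Sketch: OPTIONAL_QUERYABLES with a second, hypothetical enum field.
    # ALL_QUERYABLES is built once at import time via a dict merge, so new
    # entries belong in the source dicts rather than being merged in later.
    OPTIONAL_QUERYABLES: Dict[str, Dict[str, Any]] = {
        "platform": {
            "$enum": True,
            "description": "Satellite platform identifier",
        },
        "constellation": {  # hypothetical
            "$enum": True,
            "description": "Satellite constellation identifier",
        },
    }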
3 changes: 2 additions & 1 deletion stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py
@@ -37,6 +37,7 @@
TokenPaginationExtension,
TransactionExtension,
)
from stac_fastapi.extensions.core.filter import FilterConformanceClasses
from stac_fastapi.extensions.third_party import BulkTransactionExtension
from stac_fastapi.sfeos_helpers.aggregation import EsAsyncBaseAggregationClient
from stac_fastapi.sfeos_helpers.filter import EsAsyncBaseFiltersClient
@@ -56,7 +57,7 @@
client=EsAsyncBaseFiltersClient(database=database_logic)
)
filter_extension.conformance_classes.append(
"http://www.opengis.net/spec/cql2/1.0/conf/advanced-comparison-operators"
FilterConformanceClasses.ADVANCED_COMPARISON_OPERATORS
)

aggregation_extension = AggregationExtension(
31 changes: 31 additions & 0 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
@@ -895,6 +895,37 @@ async def get_items_mapping(self, collection_id: str) -> Dict[str, Any]:
except ESNotFoundError:
raise NotFoundError(f"Mapping for index {index_name} not found")

async def get_items_unique_values(
self, collection_id: str, field_names: Iterable[str], *, limit: int = 100
) -> Dict[str, List[str]]:
"""Get the unique values for the given fields in the collection."""
limit_plus_one = limit + 1
index_name = index_alias_by_collection_id(collection_id)

query = await self.client.search(
index=index_name,
body={
"size": 0,
"aggs": {
field: {"terms": {"field": field, "size": limit_plus_one}}
for field in field_names
},
},
)

result: Dict[str, List[str]] = {}
for field, agg in query["aggregations"].items():
if len(agg["buckets"]) > limit:
logger.warning(
"Skipping enum field %s: exceeds limit of %d unique values. "
"Consider excluding this field from enumeration or increase the limit.",
field,
limit,
)
continue
result[field] = [bucket["key"] for bucket in agg["buckets"]]
return result

async def create_collection(self, collection: Collection, **kwargs: Any):
"""Create a single collection in the database.

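For context, a sketch of the terms-aggregation response shape the parsing loop above relies on (a hypothetical response, not captured from a real cluster): each requested field comes back as a named aggregation whose buckets carry a key and a doc_count, and because limit + 1 buckets are requested, receiving more than limit buckets is how an over-limit field is detected and skipped.

    # Hypothetical search response for field_names=["properties.platform"], limit=100.
    query = {
        "hits": {"total": {"value": 3, "relation": "eq"}, "hits": []},  # size=0: no documents returned
        "aggregations": {
            "properties.platform": {
                "doc_count_error_upper_bound": 0,
                "sum_other_doc_count": 0,
                "buckets": [
                    {"key": "landsat-8", "doc_count": 2},
                    {"key": "sentinel-2", "doc_count": 1},
                ],
            }
        },
    }

    # The method then reduces each aggregation to its bucket keys:
    unique = [bucket["key"] for bucket in query["aggregations"]["properties.platform"]["buckets"]]
    # -> ["landsat-8", "sentinel-2"]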
3 changes: 2 additions & 1 deletion stac_fastapi/opensearch/stac_fastapi/opensearch/app.py
@@ -31,6 +31,7 @@
TokenPaginationExtension,
TransactionExtension,
)
from stac_fastapi.extensions.core.filter import FilterConformanceClasses
from stac_fastapi.extensions.third_party import BulkTransactionExtension
from stac_fastapi.opensearch.config import OpensearchSettings
from stac_fastapi.opensearch.database_logic import (
@@ -56,7 +57,7 @@
client=EsAsyncBaseFiltersClient(database=database_logic)
)
filter_extension.conformance_classes.append(
"http://www.opengis.net/spec/cql2/1.0/conf/advanced-comparison-operators"
FilterConformanceClasses.ADVANCED_COMPARISON_OPERATORS
)

aggregation_extension = AggregationExtension(
31 changes: 31 additions & 0 deletions stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
@@ -904,6 +904,37 @@ async def get_items_mapping(self, collection_id: str) -> Dict[str, Any]:
except exceptions.NotFoundError:
raise NotFoundError(f"Mapping for index {index_name} not found")

async def get_items_unique_values(
self, collection_id: str, field_names: Iterable[str], *, limit: int = 100
) -> Dict[str, List[str]]:
"""Get the unique values for the given fields in the collection."""
limit_plus_one = limit + 1
index_name = index_alias_by_collection_id(collection_id)

query = await self.client.search(
index=index_name,
body={
"size": 0,
"aggs": {
field: {"terms": {"field": field, "size": limit_plus_one}}
for field in field_names
},
},
)

result: Dict[str, List[str]] = {}
for field, agg in query["aggregations"].items():
if len(agg["buckets"]) > limit:
logger.warning(
"Skipping enum field %s: exceeds limit of %d unique values. "
"Consider excluding this field from enumeration or increase the limit.",
field,
limit,
)
continue
result[field] = [bucket["key"] for bucket in agg["buckets"]]
return result

async def create_collection(self, collection: Collection, **kwargs: Any):
"""Create a single collection in the database.

@@ -1,12 +1,12 @@
"""Filter client implementation for Elasticsearch/OpenSearch."""

from collections import deque
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import attr

from stac_fastapi.core.base_database_logic import BaseDatabaseLogic
from stac_fastapi.core.extensions.filter import DEFAULT_QUERYABLES
from stac_fastapi.core.extensions.filter import ALL_QUERYABLES, DEFAULT_QUERYABLES
from stac_fastapi.extensions.core.filter.client import AsyncBaseFiltersClient
from stac_fastapi.sfeos_helpers.mappings import ES_MAPPING_TYPE_TO_JSON

@@ -59,31 +59,31 @@ async def get_queryables(

mapping_data = await self.database.get_items_mapping(collection_id)
mapping_properties = next(iter(mapping_data.values()))["mappings"]["properties"]
stack = deque(mapping_properties.items())
stack: deque[Tuple[str, Dict[str, Any]]] = deque(mapping_properties.items())
enum_fields: Dict[str, Dict[str, Any]] = {}

while stack:
field_name, field_def = stack.popleft()
field_fqn, field_def = stack.popleft()

# Iterate over nested fields
field_properties = field_def.get("properties")
if field_properties:
# Fields in Item Properties should be exposed with their un-prefixed names,
# and not require expressions to prefix them with properties,
# e.g., eo:cloud_cover instead of properties.eo:cloud_cover.
if field_name == "properties":
stack.extend(field_properties.items())
else:
stack.extend(
(f"{field_name}.{k}", v) for k, v in field_properties.items()
)
stack.extend(
(f"{field_fqn}.{k}", v) for k, v in field_properties.items()
)

# Skip non-indexed or disabled fields
field_type = field_def.get("type")
if not field_type or not field_def.get("enabled", True):
continue

# Fields in Item Properties should be exposed with their un-prefixed names,
# and not require expressions to prefix them with properties,
# e.g., eo:cloud_cover instead of properties.eo:cloud_cover.
field_name = field_fqn.removeprefix("properties.")

# Generate field properties
field_result = DEFAULT_QUERYABLES.get(field_name, {})
field_result = ALL_QUERYABLES.get(field_name, {})
properties[field_name] = field_result

field_name_human = field_name.replace("_", " ").title()
@@ -95,4 +95,13 @@
if field_type in {"date", "date_nanos"}:
field_result.setdefault("format", "date-time")

if field_result.pop("$enum", False):
enum_fields[field_fqn] = field_result

if enum_fields:
for field_fqn, unique_values in (
await self.database.get_items_unique_values(collection_id, enum_fields)
).items():
enum_fields[field_fqn]["enum"] = unique_values

return queryables
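
To illustrate the end-to-end effect (an illustrative fragment, not output captured from this PR): for a collection whose items carry two distinct platform values, the published queryables would now advertise those values, roughly as follows.

    # Hypothetical fragment of GET /collections/{collection_id}/queryables after this change.
    # The "$enum" sentinel from OPTIONAL_QUERYABLES has been popped and replaced with the
    # distinct values returned by get_items_unique_values().
    queryables_fragment = {
        "properties": {
            "platform": {
                "description": "Satellite platform identifier",
                "type": "string",  # assuming the field is mapped as keyword
                "enum": ["landsat-8", "sentinel-2"],
            },
        },
    }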
16 changes: 13 additions & 3 deletions stac_fastapi/tests/conftest.py
@@ -27,7 +27,9 @@
from stac_fastapi.core.rate_limit import setup_rate_limit
from stac_fastapi.core.route_dependencies import get_route_dependencies
from stac_fastapi.core.utilities import get_bool_env
from stac_fastapi.extensions.core.filter import FilterConformanceClasses
from stac_fastapi.sfeos_helpers.aggregation import EsAsyncBaseAggregationClient
from stac_fastapi.sfeos_helpers.filter import EsAsyncBaseFiltersClient

if os.getenv("BACKEND", "elasticsearch").lower() == "opensearch":
from stac_fastapi.opensearch.config import AsyncOpensearchSettings as AsyncSettings
@@ -39,9 +41,11 @@
)
else:
from stac_fastapi.elasticsearch.config import (
ElasticsearchSettings as SearchSettings,
AsyncElasticsearchSettings as AsyncSettings,
)
from stac_fastapi.elasticsearch.config import (
ElasticsearchSettings as SearchSettings,
)
from stac_fastapi.elasticsearch.database_logic import (
DatabaseLogic,
create_collection_index,
@@ -198,6 +202,13 @@ def bulk_txn_client():
async def app():
settings = AsyncSettings()

filter_extension = FilterExtension(
client=EsAsyncBaseFiltersClient(database=database)
)
filter_extension.conformance_classes.append(
FilterConformanceClasses.ADVANCED_COMPARISON_OPERATORS
)

aggregation_extension = AggregationExtension(
client=EsAsyncBaseAggregationClient(
database=database, session=None, settings=settings
@@ -217,7 +228,7 @@ async def app():
FieldsExtension(),
QueryExtension(),
TokenPaginationExtension(),
FilterExtension(),
filter_extension,
FreeTextExtension(),
]

@@ -313,7 +324,6 @@ async def app_client_rate_limit(app_rate_limit):

@pytest_asyncio.fixture(scope="session")
async def app_basic_auth():

stac_fastapi_route_dependencies = """[
{
"routes":[{"method":"*","path":"*"}],
51 changes: 50 additions & 1 deletion stac_fastapi/tests/extensions/test_filter.py
@@ -1,10 +1,13 @@
import json
import logging
import os
import uuid
from os import listdir
from os.path import isfile, join
from typing import Callable, Dict

import pytest
from httpx import AsyncClient

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

@@ -40,7 +43,6 @@ async def test_filter_extension_collection_link(app_client, load_test_data):

@pytest.mark.asyncio
async def test_search_filters_post(app_client, ctx):

filters = []
pwd = f"{THIS_DIR}/cql2"
for fn in [fn for f in listdir(pwd) if isfile(fn := join(pwd, f))]:
@@ -625,3 +627,50 @@ async def test_search_filter_extension_cql2text_s_disjoint_property(app_client,
assert resp.status_code == 200
resp_json = resp.json()
assert len(resp_json["features"]) == 1


@pytest.mark.asyncio
async def test_queryables_enum_platform(
app_client: AsyncClient,
load_test_data: Callable[[str], Dict],
monkeypatch: pytest.MonkeyPatch,
):
# Arrange
# Enforce instant database refresh
# TODO: Is there a better way to do this?
monkeypatch.setenv("DATABASE_REFRESH", "true")

# Create collection
collection_data = load_test_data("test_collection.json")
collection_id = collection_data["id"] = f"enum-test-collection-{uuid.uuid4()}"
r = await app_client.post("/collections", json=collection_data)
r.raise_for_status()

# Create items with different platform values
NUM_ITEMS = 3
for i in range(1, NUM_ITEMS + 1):
item_data = load_test_data("test_item.json")
item_data["id"] = f"enum-test-item-{i}"
item_data["collection"] = collection_id
item_data["properties"]["platform"] = "landsat-8" if i % 2 else "sentinel-2"
r = await app_client.post(f"/collections/{collection_id}/items", json=item_data)
r.raise_for_status()

# Act
# Test queryables endpoint
queryables = (
(await app_client.get(f"/collections/{collection_data['id']}/queryables"))
.raise_for_status()
.json()
)

# Assert
# Verify distinct values (should only have 2 unique values despite 3 items)
properties = queryables["properties"]
platform_info = properties["platform"]
platform_values = platform_info["enum"]
assert set(platform_values) == {"landsat-8", "sentinel-2"}

# Clean up
r = await app_client.delete(f"/collections/{collection_id}")
r.raise_for_status()