Skip to content

Upgrade redisvl and improve resource management #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ product_metadata.json
product_vectors.json
data/
!backend/data
.env
.env
.python-version
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Much inspiration taken from [tiangelo/full-stack-fastapi-template](https://githu
product.py # primary API logic lives here
/db
load.py # seeds Redis DB
redis_helpers.py # redis util
utils.py # redis util
/schema
# pydantic models for serialization/validation from API
/tests
Expand Down
283 changes: 255 additions & 28 deletions backend/poetry.lock

Large diffs are not rendered by default.

41 changes: 12 additions & 29 deletions backend/productsearch/api/routes/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,21 @@

import numpy as np
from fastapi import APIRouter, Depends
from redis.commands.search.document import Document
from redis.commands.search.query import Query
from redisvl.index import AsyncSearchIndex
from redisvl.query import FilterQuery, VectorQuery
from redisvl.query.filter import FilterExpression, Tag
from redisvl.query import CountQuery, FilterQuery, VectorQuery
from redisvl.query.filter import Tag

from productsearch import config
from productsearch.api.schema.product import (
ProductSearchResponse,
ProductVectorSearchResponse,
SimilarityRequest,
)
from productsearch.db import redis_helpers

from productsearch.db import utils

router = APIRouter()


def create_count_query(filter_expression: FilterExpression) -> Query:
    """
    Build a content-free "count" query, used when the caller only wants to
    know how many records match a particular filter expression.

    Args:
        filter_expression (FilterExpression): The filter expression for the query.

    Returns:
        Query: The Redis query object.
    """
    # The string form of the filter expression is used as the query text.
    query_text = str(filter_expression)
    count_query = Query(query_text)
    # Document contents are not needed for a count.
    count_query = count_query.no_content()
    return count_query.dialect(2)


@router.get(
"/",
response_model=ProductSearchResponse,
Expand All @@ -45,7 +28,7 @@ async def get_products(
skip: int = 0,
gender: str = "",
category: str = "",
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
index: AsyncSearchIndex = Depends(utils.get_async_index),
) -> ProductSearchResponse:
"""Fetch and return products based on gender and category fields

Expand Down Expand Up @@ -76,7 +59,7 @@ async def get_products(
)
async def find_products_by_image(
similarity_request: SimilarityRequest,
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
index: AsyncSearchIndex = Depends(utils.get_async_index),
) -> ProductVectorSearchResponse:
"""Fetch and return products based on image similarity

Expand Down Expand Up @@ -116,14 +99,14 @@ async def find_products_by_image(
return_fields=config.RETURN_FIELDS,
filter_expression=filter_expression,
)
count_query = create_count_query(filter_expression)
count_query = CountQuery(filter_expression)

# Execute search
count, result_papers = await asyncio.gather(
index.search(count_query), index.query(paper_similarity_query)
index.query(count_query), index.query(paper_similarity_query)
)
# Get Paper records of those results
return ProductVectorSearchResponse(total=count.total, products=result_papers)
return ProductVectorSearchResponse(total=count, products=result_papers)


@router.post(
Expand All @@ -134,7 +117,7 @@ async def find_products_by_image(
)
async def find_products_by_text(
similarity_request: SimilarityRequest,
index: AsyncSearchIndex = Depends(redis_helpers.get_async_index),
index: AsyncSearchIndex = Depends(utils.get_async_index),
) -> ProductVectorSearchResponse:
"""Fetch and return products based on image similarity

Expand Down Expand Up @@ -174,11 +157,11 @@ async def find_products_by_text(
return_fields=config.RETURN_FIELDS,
filter_expression=filter_expression,
)
count_query = create_count_query(filter_expression)
count_query = CountQuery(filter_expression)

# Execute search
count, result_papers = await asyncio.gather(
index.search(count_query), index.query(paper_similarity_query)
index.query(count_query), index.query(paper_similarity_query)
)
# Get Paper records of those results
return ProductVectorSearchResponse(total=count.total, products=result_papers)
return ProductVectorSearchResponse(total=count, products=result_papers)
11 changes: 5 additions & 6 deletions backend/productsearch/db/load.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
#!/usr/bin/env python3
import asyncio
import json
import os
from typing import List

import numpy as np
import requests
from productsearch import config
from redisvl.index import AsyncSearchIndex

from productsearch import config
from productsearch.db.utils import get_schema


def read_from_s3():
res = requests.get(config.S3_DATA_URL)
Expand Down Expand Up @@ -58,10 +59,8 @@ def preprocess(product: dict) -> dict:


async def load_data():
index = AsyncSearchIndex.from_yaml(
os.path.join("./productsearch/db/schema", "products.yml")
)
index.connect(config.REDIS_URL)
schema = get_schema()
index = AsyncSearchIndex(schema, redis_url=config.REDIS_URL)

# Check if index exists
if await index.exists() and len((await index.search("*")).docs) > 0:
Expand Down
36 changes: 0 additions & 36 deletions backend/productsearch/db/redis_helpers.py

This file was deleted.

25 changes: 25 additions & 0 deletions backend/productsearch/db/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging
import os

from redisvl.index import AsyncSearchIndex
from redisvl.schema import IndexSchema

from productsearch import config

logger = logging.getLogger(__name__)

# Module-level cache for the shared search index. It stays None until
# get_async_index() constructs the index on first use, after which the same
# object is returned to every caller.
_global_index = None


def get_schema() -> IndexSchema:
    """Load the product index schema from the YAML file next to this module.

    The path is resolved relative to this file (``schema/products.yml``), so
    the result does not depend on the current working directory.

    Returns:
        IndexSchema: The schema parsed from ``products.yml``.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    # Build the whole path with os.path.join rather than concatenating a
    # hard-coded "/" separator onto the directory name.
    schema_path = os.path.join(module_dir, "schema", "products.yml")
    return IndexSchema.from_yaml(schema_path)


async def get_async_index() -> AsyncSearchIndex:
    """Return the shared ``AsyncSearchIndex``, creating it on first use.

    The index is built from ``get_schema()`` and ``config.REDIS_URL`` and
    cached in the module-level ``_global_index`` so later calls reuse the
    same object.

    Returns:
        AsyncSearchIndex: The cached index instance.
    """
    global _global_index
    # Compare against None explicitly so the cache check never depends on the
    # truthiness of the index object itself.
    if _global_index is None:
        _global_index = AsyncSearchIndex(get_schema(), redis_url=config.REDIS_URL)
    return _global_index
15 changes: 14 additions & 1 deletion backend/productsearch/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
from pathlib import Path

import uvicorn
Expand All @@ -7,10 +8,22 @@

from productsearch import config
from productsearch.api.main import api_router
from productsearch.db.utils import get_async_index
from productsearch.spa import SinglePageApplication


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Keep the shared search index entered for the application's lifetime.

    The index returned by ``get_async_index()`` is entered as an async
    context manager before control is yielded to the application, and
    exited once the application is done.
    """
    async with await get_async_index():
        yield


app = FastAPI(
title=config.PROJECT_NAME, docs_url=config.API_DOCS, openapi_url=config.OPENAPI_DOCS
title=config.PROJECT_NAME,
docs_url=config.API_DOCS,
openapi_url=config.OPENAPI_DOCS,
lifespan=lifespan,
)

app.add_middleware(
Expand Down
21 changes: 10 additions & 11 deletions backend/productsearch/tests/api/routes/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,29 @@
from httpx import AsyncClient

from productsearch.api.schema.product import SimilarityRequest
from productsearch.main import app


@pytest.fixture
def gender(products):
return products[0]["gender"]
@pytest.fixture(scope="module")
def gender(test_data):
return test_data[0]["gender"]


@pytest.fixture
def category(products):
return products[0]["category"]
@pytest.fixture(scope="module")
def category(test_data):
return test_data[0]["category"]


@pytest.fixture
@pytest.fixture(scope="module")
def bad_req_json():
return {"not": "valid"}


@pytest.fixture
def product_req(gender, category, products):
@pytest.fixture(scope="module")
def product_req(gender, category, test_data):
return SimilarityRequest(
gender=gender,
category=category,
product_id=products[0]["product_id"],
product_id=test_data[0]["product_id"],
)


Expand Down
65 changes: 44 additions & 21 deletions backend/productsearch/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,59 @@
from asyncio import get_event_loop
from typing import Generator
import json
import os

import httpx
import numpy as np
import pytest
import pytest_asyncio
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
from redis.asyncio import Redis
from redisvl.index import SearchIndex

from productsearch import config
from productsearch.db.utils import get_schema
from productsearch.main import app
from productsearch.tests.utils.seed import seed_test_db


@pytest.fixture(scope="module")
def products():
products = seed_test_db()
return products
@pytest.fixture(scope="session")
def index():
    """Session-scoped synchronous search index for the test suite.

    Builds a ``SearchIndex`` from the project schema and ``config.REDIS_URL``,
    calls ``create()`` on it, yields it to the tests, and disconnects it
    during teardown.
    """
    search_index = SearchIndex(schema=get_schema(), redis_url=config.REDIS_URL)
    search_index.create()
    yield search_index
    search_index.disconnect()


@pytest.fixture
async def client():
client = await Redis.from_url(config.REDIS_URL)
yield client
try:
await client.aclose()
except RuntimeError as e:
if "Event loop is closed" not in str(e):
raise
@pytest.fixture(scope="session", autouse=True)
def test_data(index):
cwd = os.getcwd()
with open(f"{cwd}/productsearch/tests/test_vectors.json", "r") as f:
products = json.load(f)

parsed_products = []

@pytest_asyncio.fixture(scope="session")
async def async_client():
# convert to bytes
for product in products:
parsed = {}
parsed["text_vector"] = np.array(
product["text_vector"], dtype=np.float32
).tobytes()
parsed["img_vector"] = np.array(
product["img_vector"], dtype=np.float32
).tobytes()
parsed["category"] = product["product_metadata"]["master_category"]
parsed["img_url"] = product["product_metadata"]["img_url"]
parsed["name"] = product["product_metadata"]["name"]
parsed["gender"] = product["product_metadata"]["gender"]
parsed["product_id"] = product["product_id"]
parsed_products.append(parsed)

_ = index.load(data=parsed_products, id_field="product_id")
return parsed_products

async with AsyncClient(app=app, base_url="http://test/api/v1/") as client:

yield client
@pytest_asyncio.fixture(scope="session")
async def async_client():
    """Yield a session-scoped httpx ``AsyncClient`` wired to the ASGI app.

    The client talks to ``app`` in-process through ``httpx.ASGITransport``
    and is rooted at ``http://test/api/v1/``. The app is wrapped in a
    ``LifespanManager`` for the duration of the fixture.
    """
    # The manager is entered only for its effect on the app; its return value
    # is never used, so it is not bound to a name.
    async with LifespanManager(app=app):
        async with AsyncClient(
            transport=httpx.ASGITransport(app=app), base_url="http://test/api/v1/"  # type: ignore
        ) as client:
            yield client
Empty file.
34 changes: 0 additions & 34 deletions backend/productsearch/tests/utils/seed.py

This file was deleted.

Loading