Skip to content

Commit eae5683

Browse files
Update code so it can run with or without posters
1 parent d5617bb commit eae5683

File tree

3 files changed

+49
-29
lines changed

3 files changed

+49
-29
lines changed

recipes/weaviate/demo_app.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
# Constants
1010
ENV_VARS = ["WEAVIATE_URL", "WEAVIATE_API_KEY", "COHERE_API_KEY"]
11-
NUM_IMAGES_PER_ROW = 5
11+
NUM_RECOMMENDATIONS_PER_ROW = 5
1212
SEARCH_LIMIT = 10
1313

1414
# Search Mode descriptions
@@ -34,11 +34,18 @@ def display_chat_messages():
3434
with st.chat_message(message["role"]):
3535
st.markdown(message["content"])
3636
if "images" in message:
37-
for i in range(0, len(message["images"]), NUM_IMAGES_PER_ROW):
38-
cols = st.columns(NUM_IMAGES_PER_ROW)
37+
for i in range(0, len(message["images"]), NUM_RECOMMENDATIONS_PER_ROW):
38+
cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW)
3939
for j, col in enumerate(cols):
4040
if i + j < len(message["images"]):
4141
col.image(message["images"][i + j], width=200)
42+
if "titles" in message:
43+
for i in range(0, len(message["titles"]), NUM_RECOMMENDATIONS_PER_ROW):
44+
cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW)
45+
for j, col in enumerate(cols):
46+
if i + j < len(message["titles"]):
47+
col.write(message["titles"][i + j])
48+
4249

4350
def base64_to_image(base64_str):
4451
"""Convert base64 string to image"""
@@ -112,7 +119,10 @@ def perform_search(conn, movie_type, rag_prompt, year_range, mode):
112119
df = conn.query(
113120
"MovieDemo",
114121
query=movie_type,
115-
return_properties=["title", "tagline", "poster"],
122+
# Uncomment the line below if you want to use this with poster images
123+
# return_properties=["title", "tagline", "poster"],
124+
# Comment out the line below if you want to use this with poster images
125+
return_properties=["title", "tagline"],
116126
filters=(
117127
WeaviateFilter.by_property("release_year").greater_or_equal(year_range[0]) &
118128
WeaviateFilter.by_property("release_year").less_or_equal(year_range[1])
@@ -122,6 +132,8 @@ def perform_search(conn, movie_type, rag_prompt, year_range, mode):
122132
)
123133

124134
images = []
135+
titles = []
136+
125137
if df is None or df.empty:
126138
with st.chat_message("assistant"):
127139
st.write(f"No movies found matching {movie_type} and using {mode}. Please try again.")
@@ -130,16 +142,21 @@ def perform_search(conn, movie_type, rag_prompt, year_range, mode):
130142
else:
131143
with st.chat_message("assistant"):
132144
st.write("Raw search results.")
133-
cols = st.columns(NUM_IMAGES_PER_ROW)
145+
cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW)
134146
for index, row in df.iterrows():
135-
col = cols[index % NUM_IMAGES_PER_ROW]
136-
col.write(row['title'])
147+
col = cols[index % NUM_RECOMMENDATIONS_PER_ROW]
148+
if "poster" in row and row["poster"]:
149+
col.image(base64_to_image(row["poster"]), width=200)
150+
images.append(base64_to_image(row["poster"]))
151+
else:
152+
col.write(f"{row['title']}")
153+
titles.append(row["title"])
154+
137155
st.write("Now generating recommendation from these: ...")
138156

139157
st.session_state.messages.append(
140-
{"role": "assistant", "content": "Raw search results. Generating recommendation from these: ...", "images": images}
141-
)
142-
158+
{"role": "assistant", "content": "Raw search results. Generating recommendation from these: ...", "images": images, "titles": titles})
159+
143160
with conn.client() as client:
144161
collection = client.collections.get("MovieDemo")
145162
response = collection.generate.hybrid(

recipes/weaviate/helpers/add_data.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,11 @@
8383
name="budget",
8484
data_type=DataType.INT,
8585
),
86-
Property(
87-
name="poster",
88-
data_type=DataType.BLOB
89-
),
86+
# Uncomment the lines below if you want to use this with poster images
87+
# Property(
88+
# name="poster",
89+
# data_type=DataType.BLOB
90+
# ),
9091
],
9192
vectorizer_config=Configure.Vectorizer.text2vec_cohere(),
9293
vector_index_config=Configure.VectorIndex.hnsw(
@@ -109,9 +110,10 @@
109110
date_object = datetime.strptime(movie_row["release_date"], "%Y-%m-%d").replace(
110111
tzinfo=timezone.utc
111112
)
112-
img_path = (img_dir / f"{movie_row['id']}_poster.jpg")
113-
with open(img_path, "rb") as file:
114-
poster_b64 = base64.b64encode(file.read()).decode("utf-8")
113+
# Uncomment the lines below if you want to use this with poster images
114+
# img_path = (img_dir / f"{movie_row['id']}_poster.jpg")
115+
# with open(img_path, "rb") as file:
116+
# poster_b64 = base64.b64encode(file.read()).decode("utf-8")
115117

116118
props = {
117119
k: movie_row[k]
@@ -128,7 +130,8 @@
128130
props["movie_id"] = movie_row["id"]
129131
props["release_year"] = date_object.year
130132
props["genres"] = [genre["name"] for genre in movie_row["genres"]]
131-
props["poster"] = poster_b64
133+
# Uncomment the line below if you want to use this with poster images
134+
# props["poster"] = poster_b64
132135

133136
batch.add_object(properties=props, uuid=generate_uuid5(movie_row["id"]))
134137
except Exception as e:

recipes/weaviate/helpers/verify_data.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,38 @@
55
import os
66

77

8-
config = toml.load("/Users/lacosta/Desktop/PROJECTS/cookbook/recipes/weaviate/.streamlit/secrets.toml")
8+
# Construct the path to the toml file
9+
current_dir = os.path.dirname(__file__)
10+
parent_dir = os.path.dirname(current_dir)
11+
toml_file_path = os.path.join(parent_dir, ".streamlit/secrets.toml")
12+
13+
config = toml.load(toml_file_path)
914

1015
# Access values from the TOML file
1116
weaviate_api_key = config["WEAVIATE_API_KEY"]
1217
weaviate_url = config["WEAVIATE_URL"]
1318
cohere_api_key = config["COHERE_API_KEY"]
1419

15-
weaviate_url = config["WEAVIATE_URL"]
16-
weaviate_apikey = config["WEAVIATE_API_KEY"] # WCS_DEMO_ADMIN_KEY or WCS_DEMO_RO_KEY
17-
cohere_apikey = config["COHERE_API_KEY"]
18-
19-
2020
client = weaviate.connect_to_weaviate_cloud(
2121
cluster_url=weaviate_url,
22-
auth_credentials=Auth.api_key(weaviate_apikey),
22+
auth_credentials=Auth.api_key(weaviate_api_key),
2323
headers={
24-
"X-Cohere-Api-Key": cohere_apikey
24+
"X-Cohere-Api-Key": cohere_api_key
2525
}
2626
)
2727

2828
# # If you are using a local instance of Weaviate, you can use the following code
2929
# client = weaviate.connect_to_local(
3030
# headers={
31-
# "X-Cohere-Api-Key": cohere_apikey
31+
# "X-Cohere-Api-Key": cohere_api_key
3232
# }
3333
# )
3434

3535
movies = client.collections.get("MovieDemo")
3636

3737
print(movies.aggregate.over_all(total_count=True))
3838

39-
r = movies.query.fetch_objects(limit=1, return_properties=["poster"])
40-
print(r.objects[0].properties["poster"])
39+
r = movies.query.fetch_objects(limit=1, return_properties=["title"])
40+
print(r.objects[0].properties["title"])
4141

4242
client.close()

0 commit comments

Comments
 (0)