Skip to content

Commit bfdae80

Browse files
updated conversation items DB model to have a row for each item
Signed-off-by: Francisco Javier Arceo <[email protected]>
1 parent b38e6df commit bfdae80

File tree

2 files changed

+63
-54
lines changed

2 files changed

+63
-54
lines changed

llama_stack/core/conversations/conversations.py

Lines changed: 62 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ async def initialize(self) -> None:
8181
},
8282
)
8383

84+
await self.sql_store.create_table(
85+
"conversation_items",
86+
{
87+
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
88+
"conversation_id": ColumnType.STRING,
89+
"created_at": ColumnType.INTEGER,
90+
"item_data": ColumnType.JSON,
91+
},
92+
)
93+
8494
async def create_conversation(
8595
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
8696
) -> Conversation:
@@ -158,31 +168,41 @@ async def create(self, conversation_id: str, items: list[ConversationItem]) -> C
158168
await self._get_validated_conversation(conversation_id)
159169

160170
created_items = []
171+
created_at = int(time.time())
161172

162173
for item in items:
163-
# Generate item ID based on item type
164-
random_bytes = secrets.token_bytes(24)
165-
if item.type == "message":
166-
item_id = f"msg_{random_bytes.hex()}"
167-
else:
168-
item_id = f"item_{random_bytes.hex()}"
169-
170-
# Create a copy of the item with the generated ID and completed status
171174
item_dict = item.model_dump()
172-
item_dict["id"] = item_id
173-
if "status" not in item_dict:
174-
item_dict["status"] = "completed"
175-
176-
created_items.append(item_dict)
177175

178-
# Get existing items from database
179-
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
180-
existing_items = record.get("items", []) if record else []
176+
# Generate ID if not present
177+
if item.id is None:
178+
random_bytes = secrets.token_bytes(24)
179+
if item.type == "message":
180+
item_id = f"msg_{random_bytes.hex()}"
181+
else:
182+
item_id = f"item_{random_bytes.hex()}"
183+
item_dict["id"] = item_id
184+
else:
185+
item_id = item.id
186+
187+
item_record = {
188+
"id": item_id,
189+
"conversation_id": conversation_id,
190+
"created_at": created_at,
191+
"item_data": item_dict,
192+
}
193+
194+
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
195+
try:
196+
await self.sql_store.insert(table="conversation_items", data=item_record)
197+
except Exception:
198+
# If insert fails due to ID conflict, update existing record
199+
await self.sql_store.update(
200+
table="conversation_items",
201+
data={"created_at": created_at, "item_data": item_dict},
202+
where={"id": item_id},
203+
)
181204

182-
updated_items = existing_items + created_items
183-
await self.sql_store.update(
184-
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
185-
)
205+
created_items.append(item_dict)
186206

187207
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
188208

@@ -204,39 +224,37 @@ async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem
204224
if not item_id:
205225
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
206226

207-
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
208-
items = record.get("items", []) if record else []
227+
# Get item from conversation_items table
228+
record = await self.sql_store.fetch_one(
229+
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
230+
)
209231

210-
for item in items:
211-
if isinstance(item, dict) and item.get("id") == item_id:
212-
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
213-
return adapter.validate_python(item)
232+
if record is None:
233+
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
214234

215-
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
235+
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
236+
return adapter.validate_python(record["item_data"])
216237

217238
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
218239
"""List items in the conversation."""
219-
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
220-
items = record.get("items", []) if record else []
240+
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
241+
records = result.data
221242

222243
if order != NOT_GIVEN and order == "asc":
223-
items = items
244+
records.sort(key=lambda x: x["created_at"])
224245
else:
225-
items = list(reversed(items))
246+
records.sort(key=lambda x: x["created_at"], reverse=True)
226247

227248
actual_limit = 20
228249
if limit != NOT_GIVEN and isinstance(limit, int):
229250
actual_limit = limit
230251

231-
items = items[:actual_limit]
252+
records = records[:actual_limit]
253+
items = [record["item_data"] for record in records]
232254

233-
# Items from database are stored as dicts, convert them to ConversationItem
234255
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
235-
response_items: list[ConversationItem] = [
236-
adapter.validate_python(item) if isinstance(item, dict) else item for item in items
237-
]
256+
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
238257

239-
# Get first and last IDs from converted response items
240258
first_id = response_items[0].id if response_items else None
241259
last_id = response_items[-1].id if response_items else None
242260

@@ -256,26 +274,17 @@ async def openai_delete_conversation_item(
256274
if not item_id:
257275
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
258276

259-
_ = await self._get_validated_conversation(conversation_id) # executes validation
277+
_ = await self._get_validated_conversation(conversation_id)
260278

261-
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
262-
items = record.get("items", []) if record else []
263-
264-
updated_items = []
265-
item_found = False
266-
267-
for item in items:
268-
current_item_id = item.get("id") if isinstance(item, dict) else getattr(item, "id", None)
269-
if current_item_id != item_id:
270-
updated_items.append(item)
271-
else:
272-
item_found = True
279+
record = await self.sql_store.fetch_one(
280+
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
281+
)
273282

274-
if not item_found:
283+
if record is None:
275284
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
276285

277-
await self.sql_store.update(
278-
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
286+
await self.sql_store.delete(
287+
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
279288
)
280289

281290
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")

tests/unit/conversations/test_conversations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def test_conversation_items(service):
6262
item_list = await service.create(conversation.id, items)
6363

6464
assert len(item_list.data) == 1
65-
assert item_list.data[0].id.startswith("msg_")
65+
assert item_list.data[0].id == "msg_test123"
6666

6767
items = await service.list(conversation.id)
6868
assert len(items.data) == 1

0 commit comments

Comments
 (0)