Skip to content

Commit 845b9f3

Browse files
committed
fix(storage): misc fixes to serialization of basemodels
1 parent 3816be3 commit 845b9f3

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

src/storage/src/storage3/_async/vectors.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ async def create_index(
5454
dimension=dimension,
5555
distanceMetric=distance_metric,
5656
dataType=data_type,
57-
metadataConfiguration=dict(metadata) if metadata else None,
57+
metadataConfiguration=metadata.model_dump(by_alias=True)
58+
if metadata
59+
else None,
5860
)
5961
await self._request.send(http_method="POST", path=["CreateIndex"], body=body)
6062

@@ -106,7 +108,9 @@ def with_metadata(self, **data: JSON) -> JSON:
106108
)
107109

108110
async def put(self, vectors: List[VectorObject]) -> None:
109-
body = self.with_metadata(vectors=[v.as_json() for v in vectors])
111+
body = self.with_metadata(
112+
vectors=[v.model_dump(exclude_none=True) for v in vectors]
113+
)
110114
await self._request.send(http_method="POST", path=["PutVectors"], body=body)
111115

112116
async def get(
@@ -149,7 +153,7 @@ async def query(
149153
filter: Optional[VectorFilter] = None,
150154
return_distance: bool = True,
151155
return_metadata: bool = True,
152-
) -> QueryVectorsResponse:
156+
) -> List[VectorMatch]:
153157
body = self.with_metadata(
154158
queryVector=dict(query_vector),
155159
topK=topK,
@@ -160,10 +164,10 @@ async def query(
160164
data = await self._request.send(
161165
http_method="POST", path=["QueryVectors"], body=body
162166
)
163-
return QueryVectorsResponse.model_validate_json(data.content)
167+
return QueryVectorsResponse.model_validate_json(data.content).vectors
164168

165169
async def delete(self, keys: List[str]) -> None:
166-
if 1 < len(keys) or len(keys) > 500:
170+
if len(keys) < 1 or len(keys) > 500:
167171
raise VectorBucketException("Keys batch size must be between 1 and 500.")
168172
body = self.with_metadata(keys=keys)
169173
await self._request.send(http_method="POST", path=["DeleteVectors"], body=body)

src/storage/src/storage3/types.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,14 +207,11 @@ class VectorObject(BaseModel, extra=Extra.ignore):
207207
data: VectorData
208208
metadata: Optional[dict[str, Union[str, bool, float]]] = None
209209

210-
def as_json(self) -> JSON:
211-
return {"key": self.key, "data": dict(self.data), "metadata": self.metadata}
212-
213210

214211
class VectorMatch(BaseModel, extra=Extra.ignore):
215212
key: str
216213
data: Optional[VectorData] = None
217-
distance: Optional[int] = None
214+
distance: Optional[float] = None
218215
metadata: Optional[dict[str, Any]] = None
219216

220217

@@ -228,7 +225,7 @@ class ListVectorsResponse(BaseModel, extra=Extra.ignore):
228225

229226

230227
class QueryVectorsResponse(BaseModel, extra=Extra.ignore):
231-
matches: List[VectorMatch]
228+
vectors: List[VectorMatch]
232229

233230

234231
class AnalyticsBucket(BaseModel, extra=Extra.ignore):

0 commit comments

Comments
 (0)