Skip to content

Commit 74bb7dd

Browse files
committed
refactor(cli): Modify the logics of metadata processing to avoid passing unnecessary metadata.
1 parent 681b85f commit 74bb7dd

File tree

2 files changed

+37
-19
lines changed

2 files changed

+37
-19
lines changed

src/vectorcode/common.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -182,16 +182,11 @@ async def get_collection(
182182
logger.debug(
183183
f"Getting/Creating collection with the following metadata: {collection_meta}"
184184
)
185-
if not make_if_missing:
186-
__COLLECTION_CACHE[full_path] = await client.get_collection(
185+
try:
186+
collection = await client.get_collection(
187187
collection_name, embedding_function
188188
)
189-
else:
190-
collection = await client.get_or_create_collection(
191-
collection_name,
192-
metadata=collection_meta,
193-
embedding_function=embedding_function,
194-
)
189+
__COLLECTION_CACHE[full_path] = collection
195190
if (
196191
not collection.metadata.get("hostname") == socket.gethostname()
197192
or collection.metadata.get("username")
@@ -208,7 +203,17 @@ async def get_collection(
208203
raise IndexError(
209204
"Failed to create the collection due to hash collision. Please file a bug report."
210205
)
211-
__COLLECTION_CACHE[full_path] = collection
206+
except ValueError:
207+
if make_if_missing:
208+
collection = await client.create_collection(
209+
collection_name,
210+
metadata=collection_meta,
211+
embedding_function=embedding_function,
212+
)
213+
214+
__COLLECTION_CACHE[full_path] = collection
215+
else:
216+
raise
212217
return __COLLECTION_CACHE[full_path]
213218

214219

tests/test_common.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,16 @@ async def test_get_collection():
224224
with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient:
225225
mock_client = MagicMock(spec=AsyncClientAPI)
226226
mock_collection = MagicMock()
227+
mock_collection.metadata = {
228+
"path": config.project_root,
229+
"hostname": socket.gethostname(),
230+
"created-by": "VectorCode",
231+
"username": os.environ.get(
232+
"USER", os.environ.get("USERNAME", "DEFAULT_USER")
233+
),
234+
"embedding_function": config.embedding_function,
235+
"hnsw:M": 64,
236+
}
227237
mock_client.get_collection.return_value = mock_collection
228238
MockAsyncHttpClient.return_value = mock_client
229239

@@ -252,7 +262,7 @@ async def test_get_collection():
252262
"created-by": "VectorCode",
253263
}
254264

255-
async def mock_get_or_create_collection(
265+
async def mock_create_collection(
256266
self,
257267
name=None,
258268
configuration=None,
@@ -263,7 +273,7 @@ async def mock_get_or_create_collection(
263273
mock_collection.metadata.update(metadata or {})
264274
return mock_collection
265275

266-
mock_client.get_or_create_collection.side_effect = mock_get_or_create_collection
276+
mock_client.create_collection.side_effect = mock_create_collection
267277
MockAsyncHttpClient.return_value = mock_client
268278

269279
collection = await get_collection(mock_client, config, make_if_missing=True)
@@ -273,16 +283,18 @@ async def mock_get_or_create_collection(
273283
)
274284
assert collection.metadata["created-by"] == "VectorCode"
275285
assert collection.metadata["hnsw:M"] == 64
276-
mock_client.get_or_create_collection.assert_called_once()
286+
mock_client.create_collection.assert_called_once()
277287
mock_client.get_collection.side_effect = None
278288

279289
# Test raising IndexError on hash collision.
280-
with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient:
290+
with (
291+
patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient,
292+
patch("socket.gethostname", side_effect=(lambda: "dummy")),
293+
):
281294
mock_client = MagicMock(spec=AsyncClientAPI)
282-
mock_client.get_or_create_collection.side_effect = IndexError(
283-
"Hash collision occurred"
284-
)
295+
285296
MockAsyncHttpClient.return_value = mock_client
297+
mock_client.get_collection = AsyncMock(return_value=mock_collection)
286298
from vectorcode.common import __COLLECTION_CACHE
287299

288300
__COLLECTION_CACHE.clear()
@@ -315,7 +327,8 @@ async def test_get_collection_hnsw():
315327
"embedding_function": "SentenceTransformerEmbeddingFunction",
316328
"path": "/test_project",
317329
}
318-
mock_client.get_or_create_collection.return_value = mock_collection
330+
mock_client.create_collection.return_value = mock_collection
331+
mock_client.get_collection.side_effect = ValueError
319332
MockAsyncHttpClient.return_value = mock_client
320333

321334
# Clear the collection cache to force creation
@@ -332,9 +345,9 @@ async def test_get_collection_hnsw():
332345
assert collection.metadata["created-by"] == "VectorCode"
333346
assert collection.metadata["hnsw:ef_construction"] == 200
334347
assert collection.metadata["hnsw:M"] == 32
335-
mock_client.get_or_create_collection.assert_called_once()
348+
mock_client.create_collection.assert_called_once()
336349
assert (
337-
mock_client.get_or_create_collection.call_args.kwargs["metadata"]
350+
mock_client.create_collection.call_args.kwargs["metadata"]
338351
== mock_collection.metadata
339352
)
340353

0 commit comments

Comments
 (0)