Skip to content

Commit c03c490

Browse files
committed
feat(cli): Remove client cache and fix termination issues
1 parent 9f20ac8 commit c03c490

File tree

3 files changed

+31
-30
lines changed

3 files changed

+31
-30
lines changed

src/vectorcode/cli_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ class LockManager:
622622
__singleton: "LockManager"
623623

624624
def __new__(cls) -> "LockManager":
625-
if cls.__singleton is None:
625+
if not hasattr(cls, "__singleton") or cls.__singleton is None:
626626
cls.__singleton = super().__new__(cls)
627627
cls.__singleton.__locks = {}
628628
return cls.__singleton

src/vectorcode/common.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -114,30 +114,24 @@ async def start_server(configs: Config):
114114
return process
115115

116116

117-
__CLIENT_CACHE: dict[str, AsyncClientAPI] = {}
118-
119-
120117
async def get_client(configs: Config) -> AsyncClientAPI:
121-
client_entry = configs.db_url
122-
if __CLIENT_CACHE.get(client_entry) is None:
123-
settings: dict[str, Any] = {"anonymized_telemetry": False}
124-
if isinstance(configs.db_settings, dict):
125-
valid_settings = {
126-
k: v for k, v in configs.db_settings.items() if k in Settings.__fields__
127-
}
128-
settings.update(valid_settings)
129-
parsed_url = urlparse(configs.db_url)
130-
settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
131-
settings["chroma_server_http_port"] = parsed_url.port or 8000
132-
settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
133-
settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2
134-
settings_obj = Settings(**settings)
135-
__CLIENT_CACHE[client_entry] = await chromadb.AsyncHttpClient(
136-
settings=settings_obj,
137-
host=str(settings_obj.chroma_server_host),
138-
port=int(settings_obj.chroma_server_http_port or 8000),
139-
)
140-
return __CLIENT_CACHE[client_entry]
118+
settings: dict[str, Any] = {"anonymized_telemetry": False}
119+
if isinstance(configs.db_settings, dict):
120+
valid_settings = {
121+
k: v for k, v in configs.db_settings.items() if k in Settings.__fields__
122+
}
123+
settings.update(valid_settings)
124+
parsed_url = urlparse(configs.db_url)
125+
settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
126+
settings["chroma_server_http_port"] = parsed_url.port or 8000
127+
settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
128+
settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2
129+
settings_obj = Settings(**settings)
130+
return await chromadb.AsyncHttpClient(
131+
settings=settings_obj,
132+
host=str(settings_obj.chroma_server_host),
133+
port=int(settings_obj.chroma_server_http_port or 8000),
134+
)
141135

142136

143137
def get_collection_name(full_path: str) -> str:
@@ -277,7 +271,7 @@ class ClientManager:
277271
__clients: dict[str, _ClientModel]
278272

279273
def __new__(cls) -> "ClientManager":
280-
if cls.__singleton is None:
274+
if not hasattr(cls, "__singleton") or cls.__singleton is None:
281275
cls.__singleton = super().__new__(cls)
282276
cls.__singleton.__clients = {}
283277
return cls.__singleton

src/vectorcode/mcp_main.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
get_project_config,
4141
load_config_file,
4242
)
43-
from vectorcode.common import ClientManager, get_client, get_collection, get_collections
43+
from vectorcode.common import ClientManager, get_collection, get_collections
4444
from vectorcode.subcommands.prompt import prompt_by_categories
4545
from vectorcode.subcommands.query import get_query_result_files
4646

@@ -238,7 +238,7 @@ async def mcp_server():
238238
default_project_root = project_root
239239
default_config = await get_project_config(project_root)
240240
default_config.project_root = project_root
241-
default_client = await get_client(default_config)
241+
default_client = (await ClientManager().get_client(default_config)).client
242242
try:
243243
default_collection = await get_collection(default_client, default_config)
244244
logger.info("Collection initialised for %s.", project_root)
@@ -295,9 +295,16 @@ def parse_cli_args(args: Optional[list[str]] = None) -> MCPConfig:
295295

296296

297297
async def run_server(): # pragma: nocover
298-
mcp = await mcp_server()
299-
await mcp.run_stdio_async()
300-
return 0
298+
try:
299+
mcp = await mcp_server()
300+
await mcp.run_stdio_async()
301+
finally:
302+
termination_tasks: list[asyncio.Task] = []
303+
for p in ClientManager().get_processes():
304+
p.terminate()
305+
termination_tasks.append(asyncio.create_task(p.wait()))
306+
await asyncio.gather(*termination_tasks)
307+
return 0
301308

302309

303310
def main(): # pragma: nocover

0 commit comments

Comments
 (0)