Skip to content

Adapt to the lmstudio embedded provider and add example config #116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import setuptools
from setuptools import find_packages

with open("readme.md", "r") as fh:
with open("readme.md", "r", encoding="utf-8") as fh:
long_description = fh.read()


Expand Down
14 changes: 13 additions & 1 deletion src/server/api/config.yaml.example
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
llm_api_key: YOUR-OPENAI-KEY
llm_api_key: YOUR-OPENAI-KEY

# If you use the lmstudio embedding provider, you can refer to the configuration below

# embedding_provider: lmstudio
# embedding_api_key: lm_XXX
# embedding_model: text-embedding-qwen3-embedding-8b
# If you run the server in Docker and LM Studio listens on the host machine (normally reached via localhost/127.0.0.1), use the host.docker.internal base URL below; otherwise use the 127.0.0.1 form
# embedding_base_url: http://host.docker.internal:1234/v1
# embedding_base_url: http://127.0.0.1:1234/v1
# embedding_dim: 4096

# language: zh
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from ...models.database import DEFAULT_PROJECT_ID
from .jina_embedding import jina_embedding
from .openai_embedding import openai_embedding
from .lmstudio_embedding import lmstudio_embedding
from ...telemetry import telemetry_manager, HistogramMetricName, CounterMetricName
from ...utils import get_encoded_tokens

FACTORIES = {"openai": openai_embedding, "jina": jina_embedding}
FACTORIES = {"openai": openai_embedding, "jina": jina_embedding, "lmstudio": lmstudio_embedding}
assert (
CONFIG.embedding_provider in FACTORIES
), f"Unsupported embedding provider: {CONFIG.embedding_provider}"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
from typing import Literal
from ...errors import ExternalAPIError
from ...env import CONFIG, LOG
from .utils import get_lmstudio_async_client_instance

# Maps the embedding phase to the "task" hint sent with the upstream request.
# NOTE(review): these task names follow Jina's retrieval task convention
# ("retrieval.query" / "retrieval.passage"); confirm the LM Studio endpoint
# honors — or at least safely ignores — this field.
LMSTUDIO_TASK = {
    "query": "retrieval.query",
    "document": "retrieval.passage",
}

async def lmstudio_embedding(
    model: str, texts: list[str], phase: Literal["query", "document"] = "document"
) -> np.ndarray:
    """Embed a batch of texts via an LM Studio embeddings endpoint.

    Args:
        model: Name of the embedding model served by LM Studio.
        texts: Input strings to embed.
        phase: "query" or "document"; selects the task hint from LMSTUDIO_TASK.

    Returns:
        An ndarray with one embedding row per input text.

    Raises:
        ExternalAPIError: if the HTTP response status is not 200.
    """
    client = get_lmstudio_async_client_instance()
    payload = {
        "model": model,
        "input": texts,
        "task": LMSTUDIO_TASK[phase],
        "truncate": True,
        # presumably the server trims/pads to this size — confirm against config
        "dimensions": CONFIG.embedding_dim,
    }
    resp = await client.post("/embeddings", json=payload, timeout=20)
    if resp.status_code != 200:
        raise ExternalAPIError(f"Failed to embed texts: {resp.text}")
    data = resp.json()
    LOG.info(
        f"lmstudio embedding, {model}, {phase}, {data['usage']['prompt_tokens']}/{data['usage']['total_tokens']}"
    )
    # One row per returned data point, in the order the server sent them.
    return np.array([dp["embedding"] for dp in data["data"]])
10 changes: 10 additions & 0 deletions src/server/api/memobase_server/llms/embeddings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

_global_openai_async_client = None
_global_jina_async_client = None
_global_lmstudio_async_client = None


def get_openai_async_client_instance() -> AsyncOpenAI:
Expand All @@ -24,3 +25,12 @@ def get_jina_async_client_instance() -> AsyncClient:
headers={"Authorization": f"Bearer {CONFIG.embedding_api_key}"},
)
return _global_jina_async_client

def get_lmstudio_async_client_instance() -> AsyncClient:
    """Return the process-wide AsyncClient for the LM Studio embedding API.

    The client is created lazily on first call, configured from
    CONFIG.embedding_base_url and CONFIG.embedding_api_key, and reused
    for all subsequent calls.
    """
    global _global_lmstudio_async_client
    if _global_lmstudio_async_client is not None:
        return _global_lmstudio_async_client
    _global_lmstudio_async_client = AsyncClient(
        base_url=CONFIG.embedding_base_url,
        headers={"Authorization": f"Bearer {CONFIG.embedding_api_key}"},
    )
    return _global_lmstudio_async_client