Commit dd2f141

modify embedding sentence_transformers
1 parent f52824a commit dd2f141

File tree: 2 files changed (+85 -2 lines)

xinference/model/embedding/vllm/core.py

Lines changed: 45 additions & 2 deletions
@@ -89,6 +89,34 @@ def load(self):
                 is_matryoshka=True,
             )
 
+        # Set appropriate VLLM configuration parameters based on model capabilities
+        model_max_tokens = getattr(self.model_family, "max_tokens", 512)
+
+        # Set max_model_len based on model family capabilities with reasonable limits
+        max_model_len = min(model_max_tokens, 8192)
+        if "max_model_len" not in self._kwargs:
+            self._kwargs["max_model_len"] = max_model_len
+
+        # Ensure max_num_batched_tokens is sufficient for large models
+        if "max_num_batched_tokens" not in self._kwargs:
+            # max_num_batched_tokens should be at least max_model_len
+            # Set to a reasonable minimum that satisfies the constraint
+            self._kwargs["max_num_batched_tokens"] = max(4096, max_model_len)
+
+        # Configure other reasonable defaults for embedding models
+        if "gpu_memory_utilization" not in self._kwargs:
+            self._kwargs["gpu_memory_utilization"] = 0.7
+
+        # Use a smaller block size for better compatibility
+        if "block_size" not in self._kwargs:
+            self._kwargs["block_size"] = 16
+
+        logger.debug(
+            f"VLLM configuration for {self.model_family.model_name}: "
+            f"max_model_len={self._kwargs.get('max_model_len')}, "
+            f"max_num_batched_tokens={self._kwargs.get('max_num_batched_tokens')}"
+        )
+
         self._model = LLM(model=self._model_path, task="embed", **self._kwargs)
         self._tokenizer = self._model.get_tokenizer()
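Note (not part of the commit): a minimal, self-contained sketch of the defaulting pattern in the hunk above, using a hypothetical helper name apply_vllm_defaults and sample values. It illustrates that user-supplied kwargs are never overwritten and that max_num_batched_tokens is kept at or above max_model_len, as the commit's own comments require.

# Hypothetical helper mirroring the defaulting logic above (illustration only).
def apply_vllm_defaults(kwargs: dict, model_max_tokens: int = 512) -> dict:
    max_model_len = min(model_max_tokens, 8192)
    kwargs.setdefault("max_model_len", max_model_len)
    # Keep max_num_batched_tokens >= max_model_len, as noted in the commit.
    kwargs.setdefault("max_num_batched_tokens", max(4096, max_model_len))
    kwargs.setdefault("gpu_memory_utilization", 0.7)
    kwargs.setdefault("block_size", 16)
    return kwargs

# setdefault never overwrites existing keys, so explicit user settings win.
print(apply_vllm_defaults({}, model_max_tokens=32768))
# {'max_model_len': 8192, 'max_num_batched_tokens': 8192, ...}
print(apply_vllm_defaults({"max_model_len": 1024}))
# {'max_model_len': 1024, 'max_num_batched_tokens': 4096, ...}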

@@ -246,6 +274,21 @@ def _set_context_length(self):
                 self._model.llm_engine.vllm_config.model_config.max_model_len
             )
         else:
-            # v1
-            logger.warning("vLLM v1 is not supported, ignore context length setting")
+            # v1 - Get max_model_len from the v1 engine configuration
+            try:
+                # For v1, access the config differently
+                if hasattr(self._model.llm_engine, "vllm_config"):
+                    self._context_length = (
+                        self._model.llm_engine.vllm_config.model_config.max_model_len
+                    )
+                elif hasattr(self._model.llm_engine, "model_config"):
+                    self._context_length = (
+                        self._model.llm_engine.model_config.max_model_len
+                    )
+                else:
+                    # Fallback to the configured value
+                    self._context_length = self._kwargs.get("max_model_len", 512)
+            except Exception as e:
+                logger.warning(f"Failed to get context length from vLLM v1 engine: {e}")
+                self._context_length = self._kwargs.get("max_model_len", 512)
         logger.debug("Model context length: %s", self._context_length)
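Note (not part of the commit): an illustrative sketch of how the hasattr() fallback chain added above resolves the context length across the two engine layouts. The SimpleNamespace objects are hypothetical stand-ins for vLLM engine internals, not real vLLM classes.

from types import SimpleNamespace

# Same resolution order as the diff: vllm_config -> model_config -> kwargs default.
def resolve_context_length(llm_engine, kwargs):
    if hasattr(llm_engine, "vllm_config"):
        return llm_engine.vllm_config.model_config.max_model_len
    if hasattr(llm_engine, "model_config"):
        return llm_engine.model_config.max_model_len
    return kwargs.get("max_model_len", 512)

model_config = SimpleNamespace(max_model_len=8192)
new_layout = SimpleNamespace(vllm_config=SimpleNamespace(model_config=model_config))
old_layout = SimpleNamespace(model_config=model_config)
bare_engine = SimpleNamespace()

print(resolve_context_length(new_layout, {}))   # 8192
print(resolve_context_length(old_layout, {}))   # 8192
print(resolve_context_length(bare_engine, {}))  # 512 (kwargs fallback)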

xinference/model/rerank/vllm/core.py

Lines changed: 40 additions & 0 deletions
@@ -1,11 +1,15 @@
 import importlib.util
+import json
+import logging
 import uuid
 from typing import List, Optional, Union
 
 from ....types import Document, DocumentObj, Meta, Rerank, RerankTokens
 from ...utils import cache_clean
 from ..core import RerankModel, RerankModelFamilyV2, RerankSpecV1
 
+logger = logging.getLogger(__name__)
+
 SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "qwen3"]
 
 
@@ -67,6 +71,42 @@ def load(self):
                     classifier_from_token=["no", "yes"],
                     is_original_qwen3_reranker=True,
                 )
+            elif isinstance(self._kwargs["hf_overrides"], str):
+                self._kwargs["hf_overrides"] = json.loads(self._kwargs["hf_overrides"])
+                self._kwargs["hf_overrides"].update(
+                    architectures=["Qwen3ForSequenceClassification"],
+                    classifier_from_token=["no", "yes"],
+                    is_original_qwen3_reranker=True,
+                )
+
+        # Set appropriate VLLM configuration parameters based on model capabilities
+        model_max_tokens = getattr(self.model_family, "max_tokens", 512)
+
+        # Set max_model_len based on model family capabilities with reasonable limits
+        max_model_len = min(model_max_tokens, 8192)
+        if "max_model_len" not in self._kwargs:
+            self._kwargs["max_model_len"] = max_model_len
+
+        # Ensure max_num_batched_tokens is sufficient for large models
+        if "max_num_batched_tokens" not in self._kwargs:
+            # max_num_batched_tokens should be at least max_model_len
+            # Set to a reasonable minimum that satisfies the constraint
+            self._kwargs["max_num_batched_tokens"] = max(4096, max_model_len)
+
+        # Configure other reasonable defaults for reranking models
+        if "gpu_memory_utilization" not in self._kwargs:
+            self._kwargs["gpu_memory_utilization"] = 0.7
+
+        # Use a smaller block size for better compatibility
+        if "block_size" not in self._kwargs:
+            self._kwargs["block_size"] = 16
+
+        logger.debug(
+            f"VLLM configuration for rerank model {self.model_family.model_name}: "
+            f"max_model_len={self._kwargs.get('max_model_len')}, "
+            f"max_num_batched_tokens={self._kwargs.get('max_num_batched_tokens')}"
+        )
+
         self._model = LLM(model=self._model_path, task="score", **self._kwargs)
         self._tokenizer = self._model.get_tokenizer()
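Note (not part of the commit): an illustrative sketch of the hf_overrides normalization added above. A JSON string (for example one passed on the command line) is parsed into a dict before the Qwen3 reranker overrides are merged in; the max_position_embeddings entry is purely example data.

import json

# Example: hf_overrides arrives as a JSON string rather than a dict.
hf_overrides = '{"max_position_embeddings": 32768}'  # illustrative value
if isinstance(hf_overrides, str):
    hf_overrides = json.loads(hf_overrides)

# Merge in the Qwen3 reranker overrides, mirroring the keys in the diff.
hf_overrides.update(
    architectures=["Qwen3ForSequenceClassification"],
    classifier_from_token=["no", "yes"],
    is_original_qwen3_reranker=True,
)
print(hf_overrides)
# {'max_position_embeddings': 32768, 'architectures': ['Qwen3ForSequenceClassification'], ...}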
