diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index f7ffc9c511..9542b94881 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -237,7 +237,8 @@ jobs: ${{ env.SELF_HOST_PYTHON }} -m pip install tensorizer ${{ env.SELF_HOST_PYTHON }} -m pip install -U sentence-transformers ${{ env.SELF_HOST_PYTHON }} -m pip install -U FlagEmbedding - ${{ env.SELF_HOST_PYTHON }} -m pip install -U "peft>=0.15.0" + ${{ env.SELF_HOST_PYTHON }} -m pip install -U "peft<=0.17.1" + ${{ env.SELF_HOST_PYTHON }} -m pip install "vllm" --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple ${{ env.SELF_HOST_PYTHON }} -m pip install "xllamacpp>=0.2.0" --index-url https://xorbitsai.github.io/xllamacpp/whl/cu124 --extra-index-url https://pypi.org/simple ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=3000 \ --disable-warnings \ diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py index 00ea1d2b1d..c7278875ff 100644 --- a/xinference/model/embedding/core.py +++ b/xinference/model/embedding/core.py @@ -163,7 +163,7 @@ def __init__( @classmethod @abstractmethod - def check_lib(cls) -> bool: + def check_lib(cls) -> Union[bool, str]: pass @classmethod @@ -173,7 +173,7 @@ def match_json( model_family: EmbeddingModelFamilyV2, model_spec: EmbeddingSpecV1, quantization: str, - ) -> bool: + ) -> Union[bool, str]: pass @classmethod @@ -182,13 +182,15 @@ def match( model_family: EmbeddingModelFamilyV2, model_spec: EmbeddingSpecV1, quantization: str, - ): + ) -> bool: """ Return if the model_spec can be matched. """ - if not cls.check_lib(): + lib_result = cls.check_lib() + if lib_result != True: return False - return cls.match_json(model_family, model_spec, quantization) + match_result = cls.match_json(model_family, model_spec, quantization) + return match_result == True @abstractmethod def load(self): diff --git a/xinference/model/embedding/flag/core.py b/xinference/model/embedding/flag/core.py index a55bdec4b3..ad4cec7140 100644 --- a/xinference/model/embedding/flag/core.py +++ b/xinference/model/embedding/flag/core.py @@ -282,8 +282,12 @@ def encode( return result @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("FlagEmbedding") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("FlagEmbedding") is not None + else "FlagEmbedding library is not installed" + ) @classmethod def match_json( @@ -291,10 +295,15 @@ def match_json( model_family: EmbeddingModelFamilyV2, model_spec: EmbeddingSpecV1, quantization: str, - ) -> bool: + ) -> Union[bool, str]: + # Check library availability first + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + if ( model_spec.model_format in ["pytorch"] and model_family.model_name in FLAG_EMBEDDER_MODEL_LIST ): return True - return False + return f"FlagEmbedding engine only supports pytorch format and models in FLAG_EMBEDDER_MODEL_LIST, got format: {model_spec.model_format}, model: {model_family.model_name}" diff --git a/xinference/model/embedding/llama_cpp/core.py b/xinference/model/embedding/llama_cpp/core.py index eab41ec2eb..b4ec851b07 100644 --- a/xinference/model/embedding/llama_cpp/core.py +++ b/xinference/model/embedding/llama_cpp/core.py @@ -229,8 +229,12 @@ def _handle_embedding(): return Embedding(**r) # type: ignore @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("xllamacpp") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("xllamacpp") is not None + else "xllamacpp library is not installed" + ) @classmethod def match_json( @@ -238,7 +242,32 @@ def match_json( model_family: EmbeddingModelFamilyV2, model_spec: EmbeddingSpecV1, quantization: str, - ) -> bool: + ) -> Union[bool, str]: + # Check library availability + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility if model_spec.model_format not in ["ggufv2"]: - return False + return f"llama.cpp embedding only supports GGUF v2 format, got: {model_spec.model_format}" + + # Check embedding-specific requirements + if not hasattr(model_spec, "model_file_name_template"): + return "GGUF embedding model requires proper file configuration (missing model_file_name_template)" + + # Check model dimensions for llama.cpp compatibility + model_dimensions = model_family.dimensions + if model_dimensions > 4096: # llama.cpp may have limitations + return f"Large embedding model may have compatibility issues with llama.cpp ({model_dimensions} dimensions)" + + # Check platform-specific considerations + import platform + + current_platform = platform.system() + + # llama.cpp works across platforms but may have performance differences + if current_platform == "Windows": + return "llama.cpp embedding may have limited performance on Windows" + return True diff --git a/xinference/model/embedding/match_result.py b/xinference/model/embedding/match_result.py new file mode 100644 index 0000000000..3e33c268d4 --- /dev/null +++ b/xinference/model/embedding/match_result.py @@ -0,0 +1,76 @@ +""" +Error handling result structures for embedding model engine matching. + +This module provides structured error handling for engine matching operations, +allowing engines to provide detailed failure reasons and suggestions. +""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +@dataclass +class MatchResult: + """ + Result of engine matching operation with detailed error information. + + This class provides structured information about whether an engine can handle + a specific model configuration, and if not, why and what alternatives exist. + """ + + is_match: bool + reason: Optional[str] = None + error_type: Optional[str] = None + technical_details: Optional[str] = None + + @classmethod + def success(cls) -> "MatchResult": + """Create a successful match result.""" + return cls(is_match=True) + + @classmethod + def failure( + cls, + reason: str, + error_type: Optional[str] = None, + technical_details: Optional[str] = None, + ) -> "MatchResult": + """Create a failed match result with optional details.""" + return cls( + is_match=False, + reason=reason, + error_type=error_type, + technical_details=technical_details, + ) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for API responses.""" + result: Dict[str, Any] = {"is_match": self.is_match} + if not self.is_match: + if self.reason: + result["reason"] = self.reason + if self.error_type: + result["error_type"] = self.error_type + if self.technical_details: + result["technical_details"] = self.technical_details + return result + + def to_error_string(self) -> str: + """Convert to error string for backward compatibility.""" + if self.is_match: + return "Available" + error_msg = self.reason or "Unknown error" + return error_msg + + +# Error type constants for better categorization +class ErrorType: + HARDWARE_REQUIREMENT = "hardware_requirement" + OS_REQUIREMENT = "os_requirement" + MODEL_FORMAT = "model_format" + DEPENDENCY_MISSING = "dependency_missing" + MODEL_COMPATIBILITY = "model_compatibility" + DIMENSION_MISMATCH = "dimension_mismatch" + VERSION_REQUIREMENT = "version_requirement" + CONFIGURATION_ERROR = "configuration_error" + ENGINE_UNAVAILABLE = "engine_unavailable" diff --git a/xinference/model/embedding/sentence_transformers/core.py b/xinference/model/embedding/sentence_transformers/core.py index f5c42bb2df..cf46b4f761 100644 --- a/xinference/model/embedding/sentence_transformers/core.py +++ b/xinference/model/embedding/sentence_transformers/core.py @@ -429,8 +429,12 @@ def base64_to_image(base64_str: str) -> Image.Image: return result @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("sentence_transformers") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("sentence_transformers") is not None + else "sentence_transformers library is not installed" + ) @classmethod def match_json( @@ -438,6 +442,43 @@ def match_json( model_family: EmbeddingModelFamilyV2, model_spec: EmbeddingSpecV1, quantization: str, - ) -> bool: - # As default embedding engine, sentence-transformer support all models - return model_spec.model_format in ["pytorch"] + ) -> Union[bool, str]: + # Check library availability + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility + if model_spec.model_format not in ["pytorch"]: + return f"Sentence Transformers only supports pytorch format, got: {model_spec.model_format}" + + # Check model dimensions compatibility + model_dimensions = model_family.dimensions + if model_dimensions > 8192: # Extremely large embedding models + return f"Extremely large embedding model detected ({model_dimensions} dimensions), may have performance issues" + + # Check token limits + max_tokens = model_family.max_tokens + if max_tokens > 131072: # Extremely high token limits (128K) + return f"Extremely high token limit model detected (max_tokens: {max_tokens}), may cause memory issues" + + # Check for special model requirements + model_name = model_family.model_name.lower() + + # Check Qwen2 GTE models + if "gte" in model_name and "qwen2" in model_name: + # These models have specific requirements + if not hasattr(cls, "_check_qwen_gte_requirements"): + return "Qwen2 GTE models require special handling" + + # Check Qwen3 models + if "qwen3" in model_name: + # Qwen3 has flash attention requirements - basic check + try: + pass + + # This would be checked during actual loading + except Exception: + return "Qwen3 embedding model may have compatibility issues" + + return True diff --git a/xinference/model/embedding/vllm/core.py b/xinference/model/embedding/vllm/core.py index a678b5023d..6b6c149681 100644 --- a/xinference/model/embedding/vllm/core.py +++ b/xinference/model/embedding/vllm/core.py @@ -23,7 +23,7 @@ from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1 logger = logging.getLogger(__name__) -SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"] +SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "qwen3"] class VLLMEmbeddingModel(EmbeddingModel, BatchMixin): @@ -34,16 +34,44 @@ def __init__(self, *args, **kwargs): def load(self): try: + # Handle vLLM-transformers config conflict by setting environment variable + import os + + os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache_vllm" + from vllm import LLM - except ImportError: + except ImportError as e: error_message = "Failed to import module 'vllm'" installation_guide = [ "Please make sure 'vllm' is installed. ", "You can install it by `pip install vllm`\n", ] + # Check if it's a config conflict error + if "aimv2" in str(e): + error_message = ( + "vLLM has a configuration conflict with transformers library" + ) + installation_guide = [ + "This is a known issue with certain vLLM and transformers versions.", + "Try upgrading transformers or using a different vLLM version.\n", + ] + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + except Exception as e: + # Handle config registration conflicts + if "aimv2" in str(e) and "already used by a Transformers config" in str(e): + error_message = ( + "vLLM has a configuration conflict with transformers library" + ) + installation_guide = [ + "This is a known issue with certain vLLM and transformers versions.", + "Try: pip install --upgrade transformers vllm\n", + ] + raise RuntimeError(f"{error_message}\n\n{''.join(installation_guide)}") + raise + if self.model_family.model_name in { "Qwen3-Embedding-0.6B", "Qwen3-Embedding-4B", @@ -63,6 +91,34 @@ def load(self): is_matryoshka=True, ) + # Set appropriate VLLM configuration parameters based on model capabilities + model_max_tokens = getattr(self.model_family, "max_tokens", 512) + + # Set max_model_len based on model family capabilities with reasonable limits + max_model_len = min(model_max_tokens, 8192) + if "max_model_len" not in self._kwargs: + self._kwargs["max_model_len"] = max_model_len + + # Ensure max_num_batched_tokens is sufficient for large models + if "max_num_batched_tokens" not in self._kwargs: + # max_num_batched_tokens should be at least max_model_len + # Set to a reasonable minimum that satisfies the constraint + self._kwargs["max_num_batched_tokens"] = max(4096, max_model_len) + + # Configure other reasonable defaults for embedding models + if "gpu_memory_utilization" not in self._kwargs: + self._kwargs["gpu_memory_utilization"] = 0.7 + + # Use a smaller block size for better compatibility + if "block_size" not in self._kwargs: + self._kwargs["block_size"] = 16 + + logger.debug( + f"VLLM configuration for {self.model_family.model_name}: " + f"max_model_len={self._kwargs.get('max_model_len')}, " + f"max_num_batched_tokens={self._kwargs.get('max_num_batched_tokens')}" + ) + self._model = LLM(model=self._model_path, task="embed", **self._kwargs) self._tokenizer = self._model.get_tokenizer() @@ -151,8 +207,12 @@ def _create_embedding( return result @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("vllm") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("vllm") is not None + else "vllm library is not installed" + ) @classmethod def match_json( @@ -160,12 +220,47 @@ def match_json( model_family: EmbeddingModelFamilyV2, model_spec: EmbeddingSpecV1, quantization: str, - ) -> bool: - if model_spec.model_format in ["pytorch"]: - prefix = model_family.model_name.split("-", 1)[0] - if prefix in SUPPORTED_MODELS_PREFIXES: - return True - return False + ) -> Union[bool, str]: + # Check library availability first + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility + if model_spec.model_format not in ["pytorch"]: + return f"VLLM Embedding engine only supports pytorch format models, got format: {model_spec.model_format}" + + # Check model name prefix matching + prefix = model_family.model_name.split("-", 1)[0] + if prefix.lower() not in [p.lower() for p in SUPPORTED_MODELS_PREFIXES]: + return f"VLLM Embedding engine only supports models with prefixes {SUPPORTED_MODELS_PREFIXES}, got model: {model_family.model_name}" + + # Additional runtime compatibility checks for vLLM version + try: + import vllm + from packaging.version import Version + + vllm_version = Version(vllm.__version__) + + # Check for vLLM version compatibility issues + if vllm_version >= Version("0.10.0") and vllm_version < Version("0.11.0"): + # vLLM 0.10.x has V1 engine issues on CPU + import platform + + if platform.system() == "Darwin" and platform.machine() in [ + "arm64", + "arm", + ]: + # Check if this is likely to run on CPU (most common for testing) + return f"vLLM {vllm_version} has compatibility issues with embedding models on Apple Silicon CPUs. Consider using a different platform or vLLM version." + elif vllm_version >= Version("0.11.0"): + # vLLM 0.11+ should have fixed the config conflict issue + pass + except Exception: + # If version check fails, continue with basic validation + pass + + return True def wait_for_load(self): # set context length after engine inited @@ -181,6 +276,21 @@ def _set_context_length(self): self._model.llm_engine.vllm_config.model_config.max_model_len ) else: - # v1 - logger.warning("vLLM v1 is not supported, ignore context length setting") + # v1 - Get max_model_len from the v1 engine configuration + try: + # For v1, access the config differently + if hasattr(self._model.llm_engine, "vllm_config"): + self._context_length = ( + self._model.llm_engine.vllm_config.model_config.max_model_len + ) + elif hasattr(self._model.llm_engine, "model_config"): + self._context_length = ( + self._model.llm_engine.model_config.max_model_len + ) + else: + # Fallback to the configured value + self._context_length = self._kwargs.get("max_model_len", 512) + except Exception as e: + logger.warning(f"Failed to get context length from vLLM v1 engine: {e}") + self._context_length = self._kwargs.get("max_model_len", 512) logger.debug("Model context length: %s", self._context_length) diff --git a/xinference/model/embedding/vllm/tests/test_vllm_embedding.py b/xinference/model/embedding/vllm/tests/test_vllm_embedding.py index 977a0e5543..84530c34ec 100644 --- a/xinference/model/embedding/vllm/tests/test_vllm_embedding.py +++ b/xinference/model/embedding/vllm/tests/test_vllm_embedding.py @@ -16,15 +16,36 @@ import pytest +# Force import of the entire embedding module to ensure initialization +import xinference.model.embedding as embedding_module + from .....client import Client + +# Ensure embedding engines are properly initialized +# This addresses potential CI environment initialization issues +from ... import BUILTIN_EMBEDDING_MODELS, generate_engine_config_by_model_name from ...cache_manager import EmbeddingCacheManager as CacheManager from ...core import ( EmbeddingModelFamilyV2, TransformersEmbeddingSpecV1, create_embedding_model_instance, ) +from ...embed_family import EMBEDDING_ENGINES from ..core import VLLMEmbeddingModel +if "bge-small-en-v1.5" in BUILTIN_EMBEDDING_MODELS: + # Force regeneration of engine configuration + generate_engine_config_by_model_name( + BUILTIN_EMBEDDING_MODELS["bge-small-en-v1.5"][0] + ) + +# Debug: Check if vllm engine is registered +if "bge-small-en-v1.5" not in EMBEDDING_ENGINES or "vllm" not in EMBEDDING_ENGINES.get( + "bge-small-en-v1.5", {} +): + # Force re-initialization of embedding module + embedding_module._install() + TEST_MODEL_SPEC = EmbeddingModelFamilyV2( version=2, model_name="bge-small-en-v1.5", diff --git a/xinference/model/llm/core.py b/xinference/model/llm/core.py index 8abc8f04a6..5942a42879 100644 --- a/xinference/model/llm/core.py +++ b/xinference/model/llm/core.py @@ -70,7 +70,7 @@ def __init__( @classmethod @abstractmethod - def check_lib(cls) -> bool: + def check_lib(cls) -> Union[bool, str]: raise NotImplementedError @staticmethod @@ -148,15 +148,17 @@ def load(self): def match( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str ) -> bool: - if not cls.check_lib(): + lib_result = cls.check_lib() + if lib_result != True: return False - return cls.match_json(llm_family, llm_spec, quantization) + match_result = cls.match_json(llm_family, llm_spec, quantization) + return match_result == True @classmethod @abstractmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: + ) -> Union[bool, str]: raise NotImplementedError def prepare_parse_reasoning_content( diff --git a/xinference/model/llm/llama_cpp/core.py b/xinference/model/llm/llama_cpp/core.py index d009378dbe..5d379e642d 100644 --- a/xinference/model/llm/llama_cpp/core.py +++ b/xinference/model/llm/llama_cpp/core.py @@ -79,20 +79,33 @@ def _sanitize_model_config(self, llamacpp_model_config: Optional[dict]) -> dict: return llamacpp_model_config @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("xllamacpp") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("xllamacpp") is not None + else "xllamacpp library is not installed" + ) @classmethod def match_json( cls, llm_family: LLMFamilyV2, llm_spec: LLMSpecV1, quantization: str - ) -> bool: + ) -> Union[bool, str]: + # Check library availability + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility if llm_spec.model_format not in ["ggufv2"]: - return False - if ( - "chat" not in llm_family.model_ability - and "generate" not in llm_family.model_ability - ): - return False + return ( + f"llama.cpp only supports GGUF v2 format, got: {llm_spec.model_format}" + ) + + # Check memory requirements (basic heuristic) + model_size = float(str(llm_spec.model_size_in_billions)) + if model_size > 70: # Very large models + return f"llama.cpp may struggle with very large models ({model_size}B parameters)" + return True def load(self): diff --git a/xinference/model/llm/lmdeploy/core.py b/xinference/model/llm/lmdeploy/core.py index 0144a6f734..9689c3ddce 100644 --- a/xinference/model/llm/lmdeploy/core.py +++ b/xinference/model/llm/lmdeploy/core.py @@ -114,14 +114,18 @@ def load(self): raise ValueError("LMDEPLOY engine has not supported generate yet.") @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("lmdeploy") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("lmdeploy") is not None + else "lmdeploy library is not installed" + ) @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: - return False + ) -> Union[bool, str]: + return "LMDeploy base model does not support direct inference, use specific LMDeploy model classes" def generate( self, @@ -173,14 +177,23 @@ def load(self): @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: + ) -> Union[bool, str]: + # Check library availability first + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility and quantization if llm_spec.model_format == "awq": - # Currently, only 4-bit weight quantization is supported for AWQ, but got 8 bits. + # LMDeploy has specific AWQ quantization requirements if "4" not in quantization: - return False + return f"LMDeploy AWQ format requires 4-bit quantization, got: {quantization}" + + # Check model compatibility if llm_family.model_name not in LMDEPLOY_SUPPORTED_CHAT_MODELS: - return False - return LMDEPLOY_INSTALLED + return f"Chat model not supported by LMDeploy: {llm_family.model_name}" + + return True async def async_chat( self, diff --git a/xinference/model/llm/match_result.py b/xinference/model/llm/match_result.py new file mode 100644 index 0000000000..3ab90d2c37 --- /dev/null +++ b/xinference/model/llm/match_result.py @@ -0,0 +1,76 @@ +""" +Error handling result structures for engine matching. + +This module provides structured error handling for engine matching operations, +allowing engines to provide detailed failure reasons and suggestions. +""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +@dataclass +class MatchResult: + """ + Result of engine matching operation with detailed error information. + + This class provides structured information about whether an engine can handle + a specific model configuration, and if not, why and what alternatives exist. + """ + + is_match: bool + reason: Optional[str] = None + error_type: Optional[str] = None + technical_details: Optional[str] = None + + @classmethod + def success(cls) -> "MatchResult": + """Create a successful match result.""" + return cls(is_match=True) + + @classmethod + def failure( + cls, + reason: str, + error_type: Optional[str] = None, + technical_details: Optional[str] = None, + ) -> "MatchResult": + """Create a failed match result with optional details.""" + return cls( + is_match=False, + reason=reason, + error_type=error_type, + technical_details=technical_details, + ) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for API responses.""" + result: Dict[str, Any] = {"is_match": self.is_match} + if not self.is_match: + if self.reason: + result["reason"] = self.reason + if self.error_type: + result["error_type"] = self.error_type + if self.technical_details: + result["technical_details"] = self.technical_details + return result + + def to_error_string(self) -> str: + """Convert to error string for backward compatibility.""" + if self.is_match: + return "Available" + error_msg = self.reason or "Unknown error" + return error_msg + + +# Error type constants for better categorization +class ErrorType: + HARDWARE_REQUIREMENT = "hardware_requirement" + OS_REQUIREMENT = "os_requirement" + MODEL_FORMAT = "model_format" + QUANTIZATION = "quantization" + DEPENDENCY_MISSING = "dependency_missing" + MODEL_COMPATIBILITY = "model_compatibility" + ABILITY_MISMATCH = "ability_mismatch" + VERSION_REQUIREMENT = "version_requirement" + CONFIGURATION_ERROR = "configuration_error" diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index 80b9c4be2f..b391ac97b8 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -18,7 +18,6 @@ import importlib.util import logging import pathlib -import platform import sys import threading import time @@ -404,23 +403,39 @@ def wait_for_load(self): self._context_length = get_context_length(config) @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("mlx_lm") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("mlx_lm") is not None + else "mlx_lm library is not installed" + ) @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: + ) -> Union[bool, str]: + # Check library availability first + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility if llm_spec.model_format not in ["mlx"]: - return False - if sys.platform != "darwin" or platform.processor() != "arm": - # only work for Mac M chips - return False - if "generate" not in llm_family.model_ability: - return False - if "chat" in llm_family.model_ability or "vision" in llm_family.model_ability: - # do not process chat or vision - return False + return f"MLX engine only supports MLX format, got: {llm_spec.model_format}" + + # Base MLX model should not handle chat or vision models + # Those should be handled by MLXChatModel and MLXVisionModel respectively + model_abilities = getattr(llm_family, "model_ability", []) + if "chat" in model_abilities: + return False # Let MLXChatModel handle this + if "vision" in model_abilities: + return False # Let MLXVisionModel handle this + + # Check memory constraints for Apple Silicon + model_size = float(str(llm_spec.model_size_in_billions)) + if model_size > 70: # Large models may be problematic + return f"MLX may have memory limitations with very large models ({model_size}B parameters)" + return True def _get_prompt_cache( @@ -721,17 +736,30 @@ def _sanitize_generate_config( @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: + ) -> Union[bool, str]: + # Check library availability first + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility if llm_spec.model_format not in ["mlx"]: - return False - if sys.platform != "darwin" or platform.processor() != "arm": - # only work for Mac M chips - return False - if "chat" not in llm_family.model_ability: - return False - if "vision" in llm_family.model_ability: - # do not process vision - return False + return f"MLX Chat engine only supports MLX format, got: {llm_spec.model_format}" + + # Check that this model has chat ability + model_abilities = getattr(llm_family, "model_ability", []) + if "chat" not in model_abilities: + return False # Not a chat model + + # MLX Chat doesn't support vision + if "vision" in model_abilities: + return False # Let MLXVisionModel handle this + + # Check memory constraints for Apple Silicon + model_size = float(str(llm_spec.model_size_in_billions)) + if model_size > 70: # Large models may be problematic + return f"MLX Chat may have memory limitations with very large models ({model_size}B parameters)" + return True def chat( @@ -779,20 +807,36 @@ def chat( class MLXVisionModel(MLXModel, ChatModelMixin): @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("mlx_vlm") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("mlx_vlm") is not None + else "mlx_vlm library is not installed" + ) @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: + ) -> Union[bool, str]: + # Check library availability first - MLX Vision uses mlx_vlm + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility if llm_spec.model_format not in ["mlx"]: - return False - if sys.platform != "darwin" or platform.processor() != "arm": - # only work for Mac M chips - return False - if "vision" not in llm_family.model_ability: - return False + return f"MLX Vision engine only supports MLX format, got: {llm_spec.model_format}" + + # Check that this model has vision ability + model_abilities = getattr(llm_family, "model_ability", []) + if "vision" not in model_abilities: + return False # Not a vision model + + # Check memory constraints for Apple Silicon + model_size = float(str(llm_spec.model_size_in_billions)) + if model_size > 70: # Large models may be problematic + return f"MLX Vision may have memory limitations with very large models ({model_size}B parameters)" + return True def _load_model(self, **kwargs): diff --git a/xinference/model/llm/sglang/core.py b/xinference/model/llm/sglang/core.py index d3bbfc1570..7d5d13d229 100644 --- a/xinference/model/llm/sglang/core.py +++ b/xinference/model/llm/sglang/core.py @@ -334,31 +334,130 @@ def _sanitize_generate_config( return generate_config @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("sglang") is not None + def check_lib(cls) -> Union[bool, str]: + # Check CUDA first - this is the most important requirement + try: + import torch + + if not torch.cuda.is_available(): + return "SGLang requires CUDA support but no CUDA devices detected" + except ImportError: + return "SGLang requires PyTorch with CUDA support" + + if importlib.util.find_spec("sglang") is None: + return "sglang library is not installed" + + try: + if not getattr(sglang, "__version__", None): + return "SGLang version information is not available" + + # Check version - SGLang requires recent version + from packaging import version + + if version.parse(sglang.__version__) < version.parse("0.1.0"): + return f"SGLang version {sglang.__version__} is too old, minimum required is 0.1.0" + + return True + except Exception as e: + return f"Error checking SGLang library: {str(e)}" @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: - if not cls._has_cuda_device(): - return False - if not cls._is_linux(): - return False - if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]: - return False + ) -> Union[bool, str]: + # Check library availability first + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check GPU requirements + try: + import torch + + if torch.cuda.device_count() == 0: + return "SGLang requires CUDA support but no CUDA devices detected" + except ImportError: + return "SGLang requires PyTorch with CUDA support" + + # Check model format compatibility + supported_formats = ["pytorch", "gptq", "awq", "fp8", "bnb"] + if llm_spec.model_format not in supported_formats: + return f"SGLang does not support model format: {llm_spec.model_format}, supported formats: {', '.join(supported_formats)}" + + # Check quantization compatibility with format if llm_spec.model_format == "pytorch": - if quantization != "none" and not (quantization is None): - return False + if quantization != "none" and quantization is not None: + return f"SGLang pytorch format does not support quantization: {quantization}" + + # Check model compatibility with more flexible matching + def is_model_supported(model_name: str, supported_list: List[str]) -> bool: + """Check if model is supported with flexible matching.""" + # Direct match + if model_name in supported_list: + return True + + # Partial matching for models with variants (e.g., qwen3 variants) + for supported in supported_list: + if model_name.startswith( + supported.lower() + ) or supported.lower().startswith(model_name): + return True + + # Family-based matching for common patterns + model_lower = model_name.lower() + if any( + family in model_lower + for family in [ + "qwen3", + "llama", + "mistral", + "mixtral", + "qwen2", + "qwen2.5", + "deepseek", + "yi", + "baichuan", + ] + ): + # Check if there's a corresponding supported model with same family + for supported in supported_list: + if any( + family in supported.lower() + for family in [ + "qwen3", + "llama", + "mistral", + "mixtral", + "qwen2", + "qwen2.5", + "deepseek", + "yi", + "baichuan", + ] + ): + return True + + return False + if isinstance(llm_family, CustomLLMFamilyV2): - if llm_family.model_family not in SGLANG_SUPPORTED_MODELS: - return False + if not llm_family.model_family or not is_model_supported( + llm_family.model_family.lower(), SGLANG_SUPPORTED_MODELS + ): + # Instead of hard rejection, give a warning but allow usage + logger.warning( + f"Custom model family may not be fully supported by SGLang: {llm_family.model_family}" + ) else: - if llm_family.model_name not in SGLANG_SUPPORTED_MODELS: - return False - if "generate" not in llm_family.model_ability: - return False - return SGLANG_INSTALLED + if not llm_family.model_name or not is_model_supported( + llm_family.model_name.lower(), + [s.lower() for s in SGLANG_SUPPORTED_MODELS], + ): + # Instead of hard rejection, give a warning but allow usage + logger.warning( + f"Model may not be fully supported by SGLang: {llm_family.model_name}" + ) + + return True @staticmethod def _convert_state_to_completion_chunk( @@ -646,21 +745,76 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin): @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: - if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]: - return False + ) -> Union[bool, str]: + # First run base class checks + base_result = super().match_json(llm_family, llm_spec, quantization) + if base_result != True: + return base_result + + # Check model format compatibility (same as base) + supported_formats = ["pytorch", "gptq", "awq", "fp8", "bnb"] + if llm_spec.model_format not in supported_formats: + return f"SGLang Chat does not support model format: {llm_spec.model_format}" + + # Check quantization compatibility with format if llm_spec.model_format == "pytorch": - if quantization != "none" and not (quantization is None): - return False + if quantization != "none" and quantization is not None: + return f"SGLang Chat pytorch format does not support quantization: {quantization}" + + # Check chat model compatibility with more flexible matching + def is_chat_model_supported(model_name: str, supported_list: List[str]) -> bool: + """Check if chat model is supported with flexible matching.""" + # Direct match + if model_name in supported_list: + return True + + # Partial matching for models with variants + for supported in supported_list: + if model_name.startswith( + supported.lower() + ) or supported.lower().startswith(model_name): + return True + + # Family-based matching for common chat patterns + model_lower = model_name.lower() + if any(suffix in model_lower for suffix in ["chat", "instruct", "coder"]): + if any( + family in model_lower + for family in [ + "qwen3", + "llama", + "mistral", + "mixtral", + "qwen2", + "qwen2.5", + "deepseek", + "yi", + "baichuan", + ] + ): + return True + + return False + if isinstance(llm_family, CustomLLMFamilyV2): - if llm_family.model_family not in SGLANG_SUPPORTED_CHAT_MODELS: - return False + if not llm_family.model_family or not is_chat_model_supported( + llm_family.model_family.lower(), SGLANG_SUPPORTED_CHAT_MODELS + ): + # Instead of hard rejection, give a warning but allow usage + logger.warning( + f"Custom chat model may not be fully supported by SGLang: {llm_family.model_family}" + ) else: - if llm_family.model_name not in SGLANG_SUPPORTED_CHAT_MODELS: - return False - if "chat" not in llm_family.model_ability: - return False - return SGLANG_INSTALLED + if not llm_family.model_name or not is_chat_model_supported( + llm_family.model_name.lower(), + [s.lower() for s in SGLANG_SUPPORTED_CHAT_MODELS], + ): + # Instead of hard rejection, give a warning but allow usage + logger.warning( + f"Chat model may not be fully supported by SGLang: {llm_family.model_name}" + ) + + return True def _sanitize_chat_config( self, @@ -733,25 +887,81 @@ class SGLANGVisionModel(SGLANGModel, ChatModelMixin): @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: - if not cls._has_cuda_device(): - return False - if not cls._is_linux(): - return False - if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]: - return False + ) -> Union[bool, str]: + # First run base class checks + base_result = super().match_json(llm_family, llm_spec, quantization) + if base_result != True: + return base_result + + # Vision models have the same format restrictions as base SGLANG + supported_formats = ["pytorch", "gptq", "awq", "fp8", "bnb"] + if llm_spec.model_format not in supported_formats: + return ( + f"SGLang Vision does not support model format: {llm_spec.model_format}" + ) + + # Vision models typically work with specific quantization settings if llm_spec.model_format == "pytorch": - if quantization != "none" and not (quantization is None): - return False + if quantization != "none" and quantization is not None: + return f"SGLang Vision pytorch format does not support quantization: {quantization}" + + # Check vision model compatibility with more flexible matching + def is_vision_model_supported( + model_name: str, supported_list: List[str] + ) -> bool: + """Check if vision model is supported with flexible matching.""" + # Direct match + if model_name in supported_list: + return True + + # Partial matching for models with variants + for supported in supported_list: + if model_name.startswith( + supported.lower() + ) or supported.lower().startswith(model_name): + return True + + # Family-based matching for common vision patterns + model_lower = model_name.lower() + if any(suffix in model_lower for suffix in ["vision", "vl", "multi", "mm"]): + if any( + family in model_lower + for family in [ + "qwen3", + "llama", + "mistral", + "mixtral", + "qwen2", + "qwen2.5", + "deepseek", + "yi", + "baichuan", + "internvl", + ] + ): + return True + + return False + if isinstance(llm_family, CustomLLMFamilyV2): - if llm_family.model_family not in SGLANG_SUPPORTED_VISION_MODEL_LIST: - return False + if not llm_family.model_family or not is_vision_model_supported( + llm_family.model_family.lower(), SGLANG_SUPPORTED_VISION_MODEL_LIST + ): + # Instead of hard rejection, give a warning but allow usage + logger.warning( + f"Custom vision model may not be fully supported by SGLang: {llm_family.model_family}" + ) else: - if llm_family.model_name not in SGLANG_SUPPORTED_VISION_MODEL_LIST: - return False - if "vision" not in llm_family.model_ability: - return False - return SGLANG_INSTALLED + if not llm_family.model_name or not is_vision_model_supported( + llm_family.model_name.lower(), + [s.lower() for s in SGLANG_SUPPORTED_VISION_MODEL_LIST], + ): + # Instead of hard rejection, give a warning but allow usage + logger.warning( + f"Vision model may not be fully supported by SGLang: {llm_family.model_name}" + ) + + return True def _sanitize_chat_config( self, diff --git a/xinference/model/llm/transformers/core.py b/xinference/model/llm/transformers/core.py index 6ad98c38e8..39e963164b 100644 --- a/xinference/model/llm/transformers/core.py +++ b/xinference/model/llm/transformers/core.py @@ -493,20 +493,32 @@ def stop(self): del self._tokenizer @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("transformers") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("transformers") is not None + else "transformers library is not installed" + ) @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: - if llm_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]: - return False + ) -> Union[bool, str]: + # Check library availability + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility + supported_formats = ["pytorch", "gptq", "awq", "bnb"] + if llm_spec.model_format not in supported_formats: + return f"Transformers does not support model format: {llm_spec.model_format}, supported formats: {', '.join(supported_formats)}" + + # Check for models that shouldn't use Transformers by default model_family = llm_family.model_family or llm_family.model_name if model_family in NON_DEFAULT_MODEL_LIST: - return False - if "generate" not in llm_family.model_ability: - return False + return f"Model {model_family} is not recommended for Transformers engine, has specialized engine preference" + return True def build_prefill_attention_mask( @@ -965,8 +977,6 @@ def match_json( model_family = llm_family.model_family or llm_family.model_name if model_family in NON_DEFAULT_MODEL_LIST: return False - if "chat" not in llm_family.model_ability: - return False return True async def chat( diff --git a/xinference/model/llm/transformers/multimodal/core.py b/xinference/model/llm/transformers/multimodal/core.py index ae67e102b5..4d6451f42e 100644 --- a/xinference/model/llm/transformers/multimodal/core.py +++ b/xinference/model/llm/transformers/multimodal/core.py @@ -39,21 +39,18 @@ def decide_device(self): """ Update self._device """ - pass @abstractmethod def load_processor(self): """ Load self._processor and self._tokenizer """ - pass @abstractmethod def load_multimodal_model(self): """ Load self._model """ - pass def load(self): self.decide_device() @@ -71,7 +68,6 @@ def build_inputs_from_messages( actual parameters needed for inference, e.g. input_ids, attention_masks, etc. """ - pass @abstractmethod def build_generate_kwargs( @@ -82,7 +78,6 @@ def build_generate_kwargs( Hyperparameters needed for generation, e.g. temperature, max_new_tokens, etc. """ - pass @abstractmethod def build_streaming_iter( @@ -95,7 +90,6 @@ def build_streaming_iter( The length of prompt token usually comes from the input_ids. In this interface you need to call the `build_inputs_from_messages` and `build_generate_kwargs`. """ - pass def get_stop_strs(self) -> List[str]: return [] diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 4da42ed48b..7262053a50 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -850,42 +850,139 @@ def _sanitize_generate_config( return sanitized @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("vllm") is not None + def check_lib(cls) -> Union[bool, str]: + # Check CUDA first - this is the most important requirement + try: + import torch + + if not torch.cuda.is_available(): + return "vLLM requires CUDA support but no CUDA devices detected" + except ImportError: + return "vLLM requires PyTorch with CUDA support" + + if importlib.util.find_spec("vllm") is None: + return "vLLM library is not installed" + + try: + import vllm + + if not getattr(vllm, "__version__", None): + return "vLLM version information is not available" + + # Check version + from packaging import version + + if version.parse(vllm.__version__) < version.parse("0.3.0"): + return f"vLLM version {vllm.__version__} is too old, minimum required is 0.3.0" + + return True + except Exception as e: + return f"Error checking vLLM library: {str(e)}" @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: - if not cls._has_cuda_device() and not cls._has_mlu_device(): - return False - if not cls._is_linux(): - return False - if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]: - return False + ) -> Union[bool, str]: + # Check library availability first + if not VLLM_INSTALLED: + return "vLLM library is not installed" + + # Check GPU device count + try: + import torch + + if torch.cuda.device_count() == 0: + return "vLLM requires CUDA support but no CUDA devices detected" + except ImportError: + return "vLLM requires PyTorch with CUDA support" + + # Check model format + supported_formats = ["pytorch", "gptq", "awq", "fp8", "bnb"] + if llm_spec.model_format not in supported_formats: + return f"vLLM does not support model format: {llm_spec.model_format}, supported formats: {', '.join(supported_formats)}" + + # Check quantization compatibility with format if llm_spec.model_format == "pytorch": if quantization != "none" and quantization is not None: - return False + return ( + f"vLLM pytorch format does not support quantization: {quantization}" + ) + if llm_spec.model_format == "awq": - # Currently, only 4-bit weight quantization is supported for AWQ, but got 8 bits. if "4" not in quantization: - return False + return ( + f"vLLM AWQ format requires 4-bit quantization, got: {quantization}" + ) + if llm_spec.model_format == "gptq": if VLLM_INSTALLED and VLLM_VERSION >= version.parse("0.3.3"): if not any(q in quantization for q in ("3", "4", "8")): - return False + return f"vLLM GPTQ format requires 3/4/8-bit quantization, got: {quantization}" else: if "4" not in quantization: - return False + return f"Older vLLM version only supports 4-bit GPTQ, got: {quantization} (requires vLLM >= 0.3.3 for 3/8-bit)" + + # Check model compatibility with more flexible matching + def is_model_supported(model_name: str, supported_list: List[str]) -> bool: + """Check if model is supported with flexible matching.""" + # Direct match + if model_name in supported_list: + return True + + # Partial matching for models with variants (e.g., qwen3 variants) + for supported in supported_list: + if model_name.startswith( + supported.lower() + ) or supported.lower().startswith(model_name): + return True + + # Family-based matching for common patterns + model_lower = model_name.lower() + if any( + family in model_lower + for family in [ + "qwen3", + "llama", + "mistral", + "gemma", + "baichuan", + "deepseek", + ] + ): + # Check if there's a corresponding supported model with same family + for supported in supported_list: + if any( + family in supported.lower() + for family in [ + "qwen3", + "llama", + "mistral", + "gemma", + "baichuan", + "deepseek", + ] + ): + return True + + return False + if isinstance(llm_family, CustomLLMFamilyV2): - if llm_family.model_family not in VLLM_SUPPORTED_MODELS: - return False + if not llm_family.model_family or not is_model_supported( + llm_family.model_family.lower(), VLLM_SUPPORTED_MODELS + ): + return f"Custom model family may not be fully supported by vLLM: {llm_family.model_family}" else: - if llm_family.model_name not in VLLM_SUPPORTED_MODELS: - return False - if "generate" not in llm_family.model_ability: - return False - return VLLM_INSTALLED + if not is_model_supported( + llm_family.model_name.lower(), + [s.lower() for s in VLLM_SUPPORTED_MODELS], + ): + # Instead of hard rejection, give a warning but allow usage + logger.warning( + f"Model may not be fully supported by vLLM: {llm_family.model_name}" + ) + + # All checks passed + return True @staticmethod def _convert_request_output_to_completion_chunk( @@ -1292,41 +1389,91 @@ class VLLMChatModel(VLLMModel, ChatModelMixin): @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: - if llm_spec.model_format not in [ - "pytorch", - "gptq", - "awq", - "fp8", - "bnb", - "ggufv2", - ]: - return False - if llm_spec.model_format == "pytorch": - if quantization != "none" and quantization is not None: - return False - if llm_spec.model_format == "awq": - if not any(q in quantization for q in ("4", "8")): - return False - if llm_spec.model_format == "gptq": - if VLLM_INSTALLED and VLLM_VERSION >= version.parse("0.3.3"): - if not any(q in quantization for q in ("3", "4", "8")): - return False - else: - if "4" not in quantization: - return False + ) -> Union[bool, str]: + # First run base class checks + base_result = super().match_json(llm_family, llm_spec, quantization) + if base_result != True: + return base_result + + # Chat-specific format support (includes GGUFv2 for newer vLLM) + supported_formats = ["pytorch", "gptq", "awq", "fp8", "bnb", "ggufv2"] + if llm_spec.model_format not in supported_formats: + return f"vLLM Chat does not support model format: {llm_spec.model_format}" + + # GGUFv2 requires newer vLLM version if llm_spec.model_format == "ggufv2": if not (VLLM_INSTALLED and VLLM_VERSION >= version.parse("0.8.2")): - return False + return f"vLLM GGUF support requires version >= 0.8.2, current: {VLLM_VERSION}" + + # AWQ chat models support more quantization levels + if llm_spec.model_format == "awq": + if not any(q in quantization for q in ("4", "8")): + return f"vLLM Chat AWQ requires 4 or 8-bit quantization, got: {quantization}" + + # Check chat model compatibility with flexible matching + def is_chat_model_supported(model_name: str, supported_list: List[str]) -> bool: + """Check if chat model is supported with flexible matching.""" + # Direct match + if model_name in supported_list: + return True + + # Partial matching for models with variants + for supported in supported_list: + if model_name.startswith( + supported.lower() + ) or supported.lower().startswith(model_name): + return True + + # Family-based matching for common chat model patterns + model_lower = model_name.lower() + if any( + family in model_lower + for family in [ + "qwen3", + "llama", + "mistral", + "gemma", + "baichuan", + "deepseek", + "glm", + "chatglm", + ] + ): + # Check if there's a corresponding supported chat model with same family + for supported in supported_list: + if any( + family in supported.lower() + for family in [ + "qwen3", + "llama", + "mistral", + "gemma", + "baichuan", + "deepseek", + "glm", + "chatglm", + ] + ): + return True + + return False + if isinstance(llm_family, CustomLLMFamilyV2): - if llm_family.model_family not in VLLM_SUPPORTED_CHAT_MODELS: - return False + if not llm_family.model_family or not is_chat_model_supported( + llm_family.model_family.lower(), VLLM_SUPPORTED_CHAT_MODELS + ): + return f"Custom chat model may not be fully supported by vLLM: {llm_family.model_family}" else: - if llm_family.model_name not in VLLM_SUPPORTED_CHAT_MODELS: - return False - if "chat" not in llm_family.model_ability: - return False - return VLLM_INSTALLED + if not is_chat_model_supported( + llm_family.model_name.lower(), + [s.lower() for s in VLLM_SUPPORTED_CHAT_MODELS], + ): + # Instead of hard rejection, give a warning but allow usage + logger.warning( + f"Chat model may not be fully supported by vLLM: {llm_family.model_name}" + ) + + return True def _sanitize_chat_config( self, @@ -1470,39 +1617,74 @@ class VLLMMultiModel(VLLMModel, ChatModelMixin): @classmethod def match_json( cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str - ) -> bool: - if not cls._has_cuda_device() and not cls._has_mlu_device(): - return False - if not cls._is_linux(): - return False - if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]: - return False + ) -> Union[bool, str]: + # First run base class checks + base_result = super().match_json(llm_family, llm_spec, quantization) + if base_result != True: + return base_result + + # Vision models have the same format restrictions as base VLLM + supported_formats = ["pytorch", "gptq", "awq", "fp8", "bnb"] + if llm_spec.model_format not in supported_formats: + return f"vLLM Vision does not support model format: {llm_spec.model_format}" + + # Vision models typically work with specific quantization settings if llm_spec.model_format == "pytorch": if quantization != "none" and quantization is not None: - return False + return f"vLLM Vision pytorch format does not support quantization: {quantization}" + + # AWQ vision models support more quantization levels than base if llm_spec.model_format == "awq": if not any(q in quantization for q in ("4", "8")): - return False - if llm_spec.model_format == "gptq": - if VLLM_INSTALLED and VLLM_VERSION >= version.parse("0.3.3"): - if not any(q in quantization for q in ("3", "4", "8")): - return False - else: - if "4" not in quantization: - return False + return f"vLLM Vision AWQ requires 4 or 8-bit quantization, got: {quantization}" + + # Check vision model compatibility with flexible matching + def is_vision_model_supported( + model_name: str, supported_list: List[str] + ) -> bool: + """Check if vision model is supported with flexible matching.""" + # Direct match + if model_name in supported_list: + return True + + # Partial matching for models with variants + for supported in supported_list: + if model_name.startswith( + supported.lower() + ) or supported.lower().startswith(model_name): + return True + + # Family-based matching for common vision model patterns + model_lower = model_name.lower() + if any( + family in model_lower + for family in ["llama", "qwen", "internvl", "glm", "phi"] + ): + # Check if there's a corresponding supported vision model with same family + for supported in supported_list: + if any( + family in supported.lower() + for family in ["llama", "qwen", "internvl", "glm", "phi"] + ): + return True + + return False + if isinstance(llm_family, CustomLLMFamilyV2): - if llm_family.model_family not in VLLM_SUPPORTED_MULTI_MODEL_LIST: - return False + if not llm_family.model_family or not is_vision_model_supported( + llm_family.model_family.lower(), VLLM_SUPPORTED_MULTI_MODEL_LIST + ): + return f"Custom vision model may not be fully supported by vLLM: {llm_family.model_family}" else: - if llm_family.model_name not in VLLM_SUPPORTED_MULTI_MODEL_LIST: - return False - if ( - "vision" not in llm_family.model_ability - and "audio" not in llm_family.model_ability - and "omni" not in llm_family.model_ability - ): - return False - return VLLM_INSTALLED + if not llm_family.model_name or not is_vision_model_supported( + llm_family.model_name.lower(), VLLM_SUPPORTED_MULTI_MODEL_LIST + ): + # Instead of hard rejection, give a warning but allow usage + logger.warning( + f"Vision model may not be fully supported by vLLM: {llm_family.model_name}" + ) + + return True def _sanitize_model_config( self, model_config: Optional[VLLMModelConfig] diff --git a/xinference/model/rerank/core.py b/xinference/model/rerank/core.py index ae27e7e85e..f844825d6c 100644 --- a/xinference/model/rerank/core.py +++ b/xinference/model/rerank/core.py @@ -15,7 +15,7 @@ import os from abc import abstractmethod from collections import defaultdict -from typing import Dict, List, Literal, Optional +from typing import Dict, List, Literal, Optional, Union from ..._compat import BaseModel from ...types import Rerank @@ -118,7 +118,7 @@ def __init__( @classmethod @abstractmethod - def check_lib(cls) -> bool: + def check_lib(cls) -> Union[bool, str]: pass @classmethod @@ -128,7 +128,7 @@ def match_json( model_family: RerankModelFamilyV2, model_spec: RerankSpecV1, quantization: str, - ) -> bool: + ) -> Union[bool, str]: pass @classmethod @@ -137,13 +137,15 @@ def match( model_family: RerankModelFamilyV2, model_spec: RerankSpecV1, quantization: str, - ): + ) -> bool: """ Return if the model_spec can be matched. """ - if not cls.check_lib(): + lib_result = cls.check_lib() + if lib_result != True: return False - return cls.match_json(model_family, model_spec, quantization) + match_result = cls.match_json(model_family, model_spec, quantization) + return match_result == True @staticmethod def _get_tokenizer(model_path): diff --git a/xinference/model/rerank/match_result.py b/xinference/model/rerank/match_result.py new file mode 100644 index 0000000000..1cd278aa5d --- /dev/null +++ b/xinference/model/rerank/match_result.py @@ -0,0 +1,77 @@ +""" +Error handling result structures for rerank model engine matching. + +This module provides structured error handling for engine matching operations, +allowing engines to provide detailed failure reasons and suggestions. +""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +@dataclass +class MatchResult: + """ + Result of engine matching operation with detailed error information. + + This class provides structured information about whether an engine can handle + a specific model configuration, and if not, why and what alternatives exist. + """ + + is_match: bool + reason: Optional[str] = None + error_type: Optional[str] = None + technical_details: Optional[str] = None + + @classmethod + def success(cls) -> "MatchResult": + """Create a successful match result.""" + return cls(is_match=True) + + @classmethod + def failure( + cls, + reason: str, + error_type: Optional[str] = None, + technical_details: Optional[str] = None, + ) -> "MatchResult": + """Create a failed match result with optional details.""" + return cls( + is_match=False, + reason=reason, + error_type=error_type, + technical_details=technical_details, + ) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for API responses.""" + result: Dict[str, Any] = {"is_match": self.is_match} + if not self.is_match: + if self.reason: + result["reason"] = self.reason + if self.error_type: + result["error_type"] = self.error_type + if self.technical_details: + result["technical_details"] = self.technical_details + return result + + def to_error_string(self) -> str: + """Convert to error string for backward compatibility.""" + if self.is_match: + return "Available" + error_msg = self.reason or "Unknown error" + return error_msg + + +# Error type constants for better categorization +class ErrorType: + HARDWARE_REQUIREMENT = "hardware_requirement" + OS_REQUIREMENT = "os_requirement" + MODEL_FORMAT = "model_format" + DEPENDENCY_MISSING = "dependency_missing" + MODEL_COMPATIBILITY = "model_compatibility" + DIMENSION_MISMATCH = "dimension_mismatch" + VERSION_REQUIREMENT = "version_requirement" + CONFIGURATION_ERROR = "configuration_error" + ENGINE_UNAVAILABLE = "engine_unavailable" + RERANK_SPECIFIC = "rerank_specific" diff --git a/xinference/model/rerank/sentence_transformers/core.py b/xinference/model/rerank/sentence_transformers/core.py index fabbb6e593..eddc58ac06 100644 --- a/xinference/model/rerank/sentence_transformers/core.py +++ b/xinference/model/rerank/sentence_transformers/core.py @@ -16,7 +16,7 @@ import logging import threading import uuid -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Union import numpy as np import torch @@ -191,7 +191,7 @@ def compute_logits(inputs, **kwargs): from FlagEmbedding import LayerWiseFlagLLMReranker as FlagReranker else: raise RuntimeError( - f"Unsupported Rank model type: {self.model_family.type}" + f"Unsupported Rerank model type: {self.model_family.type}" ) except ImportError: error_message = "Failed to import module 'FlagEmbedding'" @@ -331,8 +331,12 @@ def format_instruction(instruction, query, doc): return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata) @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("sentence_transformers") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("sentence_transformers") is not None + else "sentence_transformers library is not installed" + ) @classmethod def match_json( @@ -340,6 +344,38 @@ def match_json( model_family: RerankModelFamilyV2, model_spec: RerankSpecV1, quantization: str, - ) -> bool: - # As default embedding engine, sentence-transformer support all models - return model_spec.model_format in ["pytorch"] + ) -> Union[bool, str]: + # Check library availability + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility + if model_spec.model_format not in ["pytorch"]: + return f"Sentence Transformers reranking only supports pytorch format, got: {model_spec.model_format}" + + # Check rerank-specific requirements + if not hasattr(model_family, "model_name"): + return "Rerank model family requires model name specification" + + # Check model type compatibility + if model_family.type and model_family.type not in [ + "rerank", + "unknown", + "cross-encoder", + "normal", + "LLM-based", + "LLM-based layerwise", + ]: + return f"Model type '{model_family.type}' may not be compatible with reranking engines" + + # Check max tokens limit for reranking performance + max_tokens = model_family.max_tokens + if max_tokens and max_tokens > 8192: # High token limits for reranking + return f"High max_tokens limit for reranking model: {max_tokens}, may cause performance issues" + + # Check language compatibility + if not model_family.language or len(model_family.language) == 0: + return "Rerank model language information is missing" + + return True diff --git a/xinference/model/rerank/vllm/core.py b/xinference/model/rerank/vllm/core.py index eac173b40c..9729a2ccc7 100644 --- a/xinference/model/rerank/vllm/core.py +++ b/xinference/model/rerank/vllm/core.py @@ -1,27 +1,58 @@ import importlib.util +import json +import logging import uuid -from typing import List, Optional +from typing import List, Optional, Union from ....types import Document, DocumentObj, Meta, Rerank, RerankTokens from ...utils import cache_clean from ..core import RerankModel, RerankModelFamilyV2, RerankSpecV1 -SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"] +logger = logging.getLogger(__name__) + +SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "qwen3"] class VLLMRerankModel(RerankModel): def load(self): try: + # Handle vLLM-transformers config conflict by setting environment variable + import os + + os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache_vllm" + from vllm import LLM - except ImportError: + except ImportError as e: error_message = "Failed to import module 'vllm'" installation_guide = [ "Please make sure 'vllm' is installed. ", "You can install it by `pip install vllm`\n", ] + # Check if it's a config conflict error + if "aimv2" in str(e): + error_message = ( + "vLLM has a configuration conflict with transformers library" + ) + installation_guide = [ + "This is a known issue with certain vLLM and transformers versions.", + "Try upgrading transformers or using a different vLLM version.\n", + ] + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + except Exception as e: + # Handle config registration conflicts + if "aimv2" in str(e) and "already used by a Transformers config" in str(e): + error_message = ( + "vLLM has a configuration conflict with transformers library" + ) + installation_guide = [ + "This is a known issue with certain vLLM and transformers versions.", + "Try: pip install --upgrade transformers vllm\n", + ] + raise RuntimeError(f"{error_message}\n\n{''.join(installation_guide)}") + raise if self.model_family.model_name in { "Qwen3-Reranker-0.6B", @@ -40,6 +71,42 @@ def load(self): classifier_from_token=["no", "yes"], is_original_qwen3_reranker=True, ) + elif isinstance(self._kwargs["hf_overrides"], str): + self._kwargs["hf_overrides"] = json.loads(self._kwargs["hf_overrides"]) + self._kwargs["hf_overrides"].update( + architectures=["Qwen3ForSequenceClassification"], + classifier_from_token=["no", "yes"], + is_original_qwen3_reranker=True, + ) + + # Set appropriate VLLM configuration parameters based on model capabilities + model_max_tokens = getattr(self.model_family, "max_tokens", 512) + + # Set max_model_len based on model family capabilities with reasonable limits + max_model_len = min(model_max_tokens, 8192) + if "max_model_len" not in self._kwargs: + self._kwargs["max_model_len"] = max_model_len + + # Ensure max_num_batched_tokens is sufficient for large models + if "max_num_batched_tokens" not in self._kwargs: + # max_num_batched_tokens should be at least max_model_len + # Set to a reasonable minimum that satisfies the constraint + self._kwargs["max_num_batched_tokens"] = max(4096, max_model_len) + + # Configure other reasonable defaults for reranking models + if "gpu_memory_utilization" not in self._kwargs: + self._kwargs["gpu_memory_utilization"] = 0.7 + + # Use a smaller block size for better compatibility + if "block_size" not in self._kwargs: + self._kwargs["block_size"] = 16 + + logger.debug( + f"VLLM configuration for rerank model {self.model_family.model_name}: " + f"max_model_len={self._kwargs.get('max_model_len')}, " + f"max_num_batched_tokens={self._kwargs.get('max_num_batched_tokens')}" + ) + self._model = LLM(model=self._model_path, task="score", **self._kwargs) self._tokenizer = self._model.get_tokenizer() @@ -139,8 +206,12 @@ def rerank( return Rerank(id=str(uuid.uuid4()), results=reranked_docs, meta=metadata) @classmethod - def check_lib(cls) -> bool: - return importlib.util.find_spec("vllm") is not None + def check_lib(cls) -> Union[bool, str]: + return ( + True + if importlib.util.find_spec("vllm") is not None + else "vllm library is not installed" + ) @classmethod def match_json( @@ -148,9 +219,62 @@ def match_json( model_family: RerankModelFamilyV2, model_spec: RerankSpecV1, quantization: str, - ) -> bool: - if model_spec.model_format in ["pytorch"]: - prefix = model_family.model_name.split("-", 1)[0] - if prefix in SUPPORTED_MODELS_PREFIXES: - return True - return False + ) -> Union[bool, str]: + # Check library availability + lib_result = cls.check_lib() + if lib_result != True: + return lib_result + + # Check model format compatibility + if model_spec.model_format not in ["pytorch"]: + return f"vLLM reranking only supports pytorch format, got: {model_spec.model_format}" + + # Check model name prefix matching + if model_spec.model_format == "pytorch": + try: + prefix = model_family.model_name.split("-", 1)[0].lower() + # Support both prefix matching and special cases + if prefix.lower() not in [p.lower() for p in SUPPORTED_MODELS_PREFIXES]: + # Special handling for Qwen3 models + if "qwen3" not in model_family.model_name.lower(): + return f"Model family prefix not supported by vLLM reranking: {prefix}" + except (IndexError, AttributeError): + return f"Unable to parse model family name for vLLM compatibility check: {model_family.model_name}" + + # Check rerank-specific requirements + if not hasattr(model_family, "model_name"): + return "Rerank model family requires model name specification for vLLM" + + # Check max tokens limit for vLLM reranking performance + max_tokens = model_family.max_tokens + if ( + max_tokens and max_tokens > 32768 + ): # vLLM has stricter limits, but Qwen3 can handle up to 32k + return f"Max tokens limit too high for vLLM reranking model: {max_tokens}, exceeds safe limit" + + # Additional runtime compatibility checks for vLLM version + try: + import vllm + from packaging.version import Version + + vllm_version = Version(vllm.__version__) + + # Check for vLLM version compatibility issues + if vllm_version >= Version("0.10.0") and vllm_version < Version("0.11.0"): + # vLLM 0.10.x has V1 engine issues on CPU + import platform + + if platform.system() == "Darwin" and platform.machine() in [ + "arm64", + "arm", + ]: + # Check if this is likely to run on CPU (most common for testing) + return f"vLLM {vllm_version} has compatibility issues with reranking models on Apple Silicon CPUs. Consider using a different platform or vLLM version." + elif vllm_version >= Version("0.11.0"): + # vLLM 0.11+ should have fixed the config conflict issue + pass + except Exception: + # If version check fails, continue with basic validation + pass + + return True diff --git a/xinference/model/rerank/vllm/tests/test_vllm.py b/xinference/model/rerank/vllm/tests/test_vllm.py index 37b948ac42..79ef529c22 100644 --- a/xinference/model/rerank/vllm/tests/test_vllm.py +++ b/xinference/model/rerank/vllm/tests/test_vllm.py @@ -3,10 +3,23 @@ import pytest from .....client import Client + +# Ensure rerank engines are properly initialized +# This addresses potential CI environment initialization issues +from ... import BUILTIN_RERANK_MODELS from ...cache_manager import RerankCacheManager from ...core import RerankModelFamilyV2, RerankSpecV1 from ..core import VLLMRerankModel +# Force import of the entire rerank module to ensure initialization + + +if "bge-reranker-base" in BUILTIN_RERANK_MODELS: + # Force regeneration of engine configuration + from ... import generate_engine_config_by_model_name + + generate_engine_config_by_model_name(BUILTIN_RERANK_MODELS["bge-reranker-base"][0]) + TEST_MODEL_SPEC = RerankModelFamilyV2( version=2, model_name="bge-reranker-base", @@ -61,6 +74,7 @@ def test_qwen3_vllm(setup): model_name="Qwen3-Reranker-0.6B", model_type="rerank", model_engine="vllm", + max_num_batched_tokens=81920, # Allow larger batch size for Qwen3 ) model = client.get_model(model_uid) diff --git a/xinference/model/utils.py b/xinference/model/utils.py index eff9b84ce3..6887d00b15 100644 --- a/xinference/model/utils.py +++ b/xinference/model/utils.py @@ -472,44 +472,533 @@ def __exit__(self, exc_type, exc_val, exc_tb): def get_engine_params_by_name( model_type: Optional[str], model_name: str -) -> Optional[Dict[str, List[dict]]]: +) -> Optional[Dict[str, Union[List[Dict[str, Any]], str]]]: + engine_params: Dict[str, Union[List[Dict[str, Any]], str]] = {} + if model_type == "LLM": - from .llm.llm_family import LLM_ENGINES + from .llm.llm_family import LLM_ENGINES, SUPPORTED_ENGINES if model_name not in LLM_ENGINES: return None - # filter llm_class - engine_params = deepcopy(LLM_ENGINES[model_name]) - for engine, params in engine_params.items(): + # Get all supported engines, not just currently available ones + all_supported_engines = list(SUPPORTED_ENGINES.keys()) + + # First add currently available engine parameters + available_engines = deepcopy(LLM_ENGINES[model_name]) + for engine, params in available_engines.items(): for param in params: - del param["llm_class"] + # Remove previous available attribute as available engines don't need this flag + if "available" in param: + del param["available"] + engine_params[engine] = params + + # Check unavailable engines with detailed error information + for engine_name in all_supported_engines: + if engine_name not in engine_params: # Engine not in available list + try: + llm_engine_classes = SUPPORTED_ENGINES[engine_name] + + # Try to get detailed error information from engine's match_with_reason + detailed_error = None + + # We need a sample model to test against, use the first available spec + if model_name in LLM_ENGINES and LLM_ENGINES[model_name]: + # Try to get model family for testing + try: + pass + + # Get the full model family instead of a single spec + from .llm.llm_family import BUILTIN_LLM_FAMILIES + + llm_family = None + for family in BUILTIN_LLM_FAMILIES: + if model_name == family.model_name: + llm_family = family + break + + if llm_family and llm_family.model_specs: + + # Test each engine class for detailed error info + for engine_class in llm_engine_classes: + try: + engine_compatible = False + error_details = None + + # Try each model spec to find one compatible with this engine + for llm_spec in llm_family.model_specs: + quantization = ( + llm_spec.quantization or "none" + ) + + if hasattr(engine_class, "match_json"): + match_result = engine_class.match_json( + llm_family, llm_spec, quantization + ) + if match_result == True: + engine_compatible = True + break # Found compatible spec + else: + # Save error details, but continue trying other specs + error_details = { + "error": ( + match_result + if isinstance( + match_result, str + ) + else "Engine is not compatible" + ), + "error_type": "model_compatibility", + "technical_details": f"The {engine_class.__name__} engine cannot handle the current model configuration: {llm_spec.model_format} format", + } + + if not engine_compatible and error_details: + detailed_error = error_details + break + except Exception as e: + # Fall back to next engine class with clear error logging + logger.warning( + f"Engine class {engine_class.__name__} match_json failed: {e}" + ) + # Continue to try next engine class, but this is expected behavior for fallback + continue + except Exception as e: + # If we can't get model family, fail with clear error + logger.error( + f"Failed to get model family for {model_name} (LLM): {e}" + ) + raise RuntimeError( + f"Unable to process LLM model {model_name}: {e}" + ) + + if detailed_error: + # Return only the error message without engine_name prefix (key already contains engine name) + engine_params[engine_name] = ( + detailed_error.get("error") or "Unknown error" + ) + else: + # Fallback to basic error checking for backward compatibility + for engine_class in llm_engine_classes: + try: + if hasattr(engine_class, "check_lib"): + lib_result = engine_class.check_lib() + if lib_result != True: + # If check_lib returns a string, it's an error message + error_msg = ( + lib_result + if isinstance(lib_result, str) + else f"Engine {engine_name} library check failed" + ) + engine_params[engine_name] = error_msg + break + else: + # If no check_lib method, try to use engine's match method for compatibility check + # This provides more detailed and accurate error information + try: + # Create a minimal test spec if we don't have real model specs + from .llm.llm_family import ( + LlamaCppLLMSpecV2, + LLMFamilyV2, + MLXLLMSpecV2, + PytorchLLMSpecV2, + ) + + # Create appropriate test spec based on engine class + engine_name_lower = ( + engine_class.__name__.lower() + ) + if "mlx" in engine_name_lower: + # MLX engines need MLX format + test_spec_class = MLXLLMSpecV2 + model_format = "mlx" + elif ( + "ggml" in engine_name_lower + or "llamacpp" in engine_name_lower + ): + # GGML/llama.cpp engines need GGML format + test_spec_class = LlamaCppLLMSpecV2 + model_format = "ggufv2" + else: + # Default to PyTorch format (supports gptq, awq, fp8, bnb) + test_spec_class = PytorchLLMSpecV2 + model_format = "pytorch" + + # Create a minimal test case with appropriate format + test_family = LLMFamilyV2( + model_name="test", + model_family="test", + model_specs=[ + test_spec_class( + model_format=model_format, + quantization="none", + ) + ], + ) + test_spec = test_family.model_specs[0] + + # Use the engine's match method if available + if hasattr(engine_class, "match_with_reason"): + result = engine_class.match_with_reason( + test_family, test_spec, "none" + ) + if result.is_match: + break # Engine is available + else: + # Return only the error message without engine_name prefix (key already contains engine name) + engine_params[engine_name] = ( + result.reason + or "Unknown compatibility error" + ) + break + elif hasattr(engine_class, "match_json"): + # Fallback to simple match method - use test data + match_result = engine_class.match_json( + test_family, test_spec, "none" + ) + if match_result == True: + break # Engine is available + else: + # Get detailed error information + error_message = ( + match_result + if isinstance(match_result, str) + else f"Engine {engine_name} is not compatible with current model or environment" + ) + engine_params[engine_name] = ( + error_message + ) + break + else: + # Final fallback: generic import check + raise ImportError( + "No compatibility check method available" + ) + + except ImportError as e: + engine_params[engine_name] = ( + f"Engine {engine_name} library is not installed: {str(e)}" + ) + break + except Exception as e: + engine_params[engine_name] = ( + f"Engine {engine_name} is not available: {str(e)}" + ) + break + except ImportError as e: + engine_params[engine_name] = ( + f"Engine {engine_name} library is not installed: {str(e)}" + ) + break + except Exception as e: + engine_params[engine_name] = ( + f"Engine {engine_name} is not available: {str(e)}" + ) + break + + # Only set default error if not already set by one of the exception handlers + if engine_name not in engine_params: + engine_params[engine_name] = ( + f"Engine {engine_name} is not compatible with current model or environment" + ) + + except Exception as e: + # If exception occurs during checking, return simple string format + engine_params[engine_name] = ( + f"Error checking engine {engine_name}: {str(e)}" + ) + + # Filter out llm_class field + for engine in engine_params.keys(): + if isinstance( + engine_params[engine], list + ): # Only process parameter lists of available engines + for param in engine_params[engine]: # type: ignore + if isinstance(param, dict) and "llm_class" in param: + del param["llm_class"] return engine_params elif model_type == "embedding": - from .embedding.embed_family import EMBEDDING_ENGINES + from .embedding.embed_family import ( + EMBEDDING_ENGINES, + ) + from .embedding.embed_family import ( + SUPPORTED_ENGINES as EMBEDDING_SUPPORTED_ENGINES, + ) if model_name not in EMBEDDING_ENGINES: return None - # filter embedding_class - engine_params = deepcopy(EMBEDDING_ENGINES[model_name]) - for engine, params in engine_params.items(): + # Get all supported engines, not just currently available ones + all_supported_engines = list(EMBEDDING_SUPPORTED_ENGINES.keys()) + + # First add currently available engine parameters + available_engines = deepcopy(EMBEDDING_ENGINES[model_name]) + for engine, params in available_engines.items(): for param in params: - del param["embedding_class"] + # Remove previous available attribute as available engines don't need this flag + if "available" in param: + del param["available"] + engine_params[engine] = params + + # Check unavailable engines + for engine_name in all_supported_engines: + if engine_name not in engine_params: # Engine not in available list + try: + embedding_engine_classes = EMBEDDING_SUPPORTED_ENGINES[engine_name] + embedding_error_details: Optional[Dict[str, str]] = None + + # Try to find specific error reasons + for embedding_engine_class in embedding_engine_classes: + try: + if hasattr(embedding_engine_class, "check_lib"): + embedding_lib_available: bool = embedding_engine_class.check_lib() # type: ignore[assignment] + if not embedding_lib_available: + embedding_error_details = { + "error": f"Engine {engine_name} library is not available", + "error_type": "dependency_missing", + "technical_details": f"The required library for {engine_name} engine is not installed or not accessible", + } + break + else: + # If no check_lib method, try to use engine's match method for compatibility check + try: + from .embedding.core import ( + EmbeddingModelFamilyV2, + TransformersEmbeddingSpecV1, + ) + + # Use the engine's match method if available + if hasattr(embedding_engine_class, "match"): + # Create a minimal test case + test_family = EmbeddingModelFamilyV2( + model_name="test", + model_specs=[ + TransformersEmbeddingSpecV1( + model_format="pytorch", + quantization="none", + ) + ], + ) + test_spec = test_family.model_specs[0] + + # Use the engine's match_json method to check compatibility and get detailed error + match_result = ( + embedding_engine_class.match_json( + test_family, test_spec, "none" + ) + ) + if match_result == True: + break # Engine is available + else: + # Get detailed error information + error_message = ( + match_result + if isinstance(match_result, str) + else f"Engine {engine_name} is not compatible with current model or environment" + ) + embedding_error_details = { + "error": error_message, + "error_type": "model_compatibility", + "technical_details": f"The {engine_name} engine cannot handle the current embedding model configuration", + } + break + else: + # Final fallback: generic import check + raise ImportError( + "No compatibility check method available" + ) + + except ImportError as e: + embedding_error_details = { + "error": f"Engine {engine_name} library is not installed: {str(e)}", + "error_type": "dependency_missing", + "technical_details": f"Missing required dependency for {engine_name} engine: {str(e)}", + } + except Exception as e: + embedding_error_details = { + "error": f"Engine {engine_name} is not available: {str(e)}", + "error_type": "configuration_error", + "technical_details": f"Configuration or environment issue preventing {engine_name} engine from working: {str(e)}", + } + break + except ImportError as e: + embedding_error_details = { + "error": f"Engine {engine_name} library is not installed: {str(e)}", + "error_type": "dependency_missing", + "technical_details": f"Missing required dependency for {engine_name} engine: {str(e)}", + } + except Exception as e: + embedding_error_details = { + "error": f"Engine {engine_name} is not available: {str(e)}", + "error_type": "configuration_error", + "technical_details": f"Configuration or environment issue preventing {engine_name} engine from working: {str(e)}", + } + + if embedding_error_details is None: + embedding_error_details = { + "error": f"Engine {engine_name} is not compatible with current model or environment", + "error_type": "model_compatibility", + "technical_details": f"The {engine_name} engine cannot handle the current embedding model configuration", + } + + # For unavailable engines, return simple string format + engine_params[engine_name] = ( + embedding_error_details.get("error") or "Unknown error" + ) + + except Exception as e: + # If exception occurs during checking, return simple string format + engine_params[engine_name] = ( + f"Error checking engine {engine_name}: {str(e)}" + ) + + # Filter out embedding_class field + for engine in engine_params.keys(): + if isinstance( + engine_params[engine], list + ): # Only process parameter lists of available engines + for param in engine_params[engine]: # type: ignore + if isinstance(param, dict) and "embedding_class" in param: + del param["embedding_class"] return engine_params elif model_type == "rerank": from .rerank.rerank_family import RERANK_ENGINES + from .rerank.rerank_family import SUPPORTED_ENGINES as RERANK_SUPPORTED_ENGINES if model_name not in RERANK_ENGINES: return None - # filter rerank_class - engine_params = deepcopy(RERANK_ENGINES[model_name]) - for engine, params in engine_params.items(): + # Get all supported engines, not just currently available ones + all_supported_engines = list(RERANK_SUPPORTED_ENGINES.keys()) + + # First add currently available engine parameters + available_engines = deepcopy(RERANK_ENGINES[model_name]) + for engine, params in available_engines.items(): for param in params: - del param["rerank_class"] + # Remove previous available attribute as available engines don't need this flag + if "available" in param: + del param["available"] + engine_params[engine] = params + + # Check unavailable engines + for engine_name in all_supported_engines: + if engine_name not in engine_params: # Engine not in available list + try: + rerank_engine_classes = RERANK_SUPPORTED_ENGINES[engine_name] + rerank_error_details: Optional[Dict[str, str]] = None + + # Try to find specific error reasons + for rerank_engine_class in rerank_engine_classes: + try: + if hasattr(rerank_engine_class, "check_lib"): + rerank_lib_available: bool = rerank_engine_class.check_lib() # type: ignore[assignment] + if not rerank_lib_available: + rerank_error_details = { + "error": f"Engine {engine_name} library is not available", + "error_type": "dependency_missing", + "technical_details": f"The required library for {engine_name} engine is not installed or not accessible", + } + break + else: + # If no check_lib method, try to use engine's match method for compatibility check + try: + from .rerank.core import ( + RerankModelFamilyV2, + RerankSpecV1, + ) + + # Use the engine's match method if available + if hasattr(rerank_engine_class, "match"): + # Create a minimal test case + test_family = RerankModelFamilyV2( + model_name="test", + model_specs=[ + RerankSpecV1( + model_format="pytorch", + quantization="none", + ) + ], + ) + test_spec = test_family.model_specs[0] + + # Use the engine's match_json method to check compatibility and get detailed error + match_result = rerank_engine_class.match_json( + test_family, test_spec, "none" + ) + if match_result == True: + break # Engine is available + else: + # Get detailed error information + error_message = ( + match_result + if isinstance(match_result, str) + else f"Engine {engine_name} is not compatible with current model or environment" + ) + rerank_error_details = { + "error": error_message, + "error_type": "model_compatibility", + "technical_details": f"The {engine_name} engine cannot handle the current rerank model configuration", + } + break + else: + # Final fallback: generic import check + raise ImportError( + "No compatibility check method available" + ) + + except ImportError as e: + rerank_error_details = { + "error": f"Engine {engine_name} library is not installed: {str(e)}", + "error_type": "dependency_missing", + "technical_details": f"Missing required dependency for {engine_name} engine: {str(e)}", + } + except Exception as e: + rerank_error_details = { + "error": f"Engine {engine_name} is not available: {str(e)}", + "error_type": "configuration_error", + "technical_details": f"Configuration or environment issue preventing {engine_name} engine from working: {str(e)}", + } + break + except ImportError as e: + rerank_error_details = { + "error": f"Engine {engine_name} library is not installed: {str(e)}", + "error_type": "dependency_missing", + "technical_details": f"Missing required dependency for {engine_name} engine: {str(e)}", + } + except Exception as e: + rerank_error_details = { + "error": f"Engine {engine_name} is not available: {str(e)}", + "error_type": "configuration_error", + "technical_details": f"Configuration or environment issue preventing {engine_name} engine from working: {str(e)}", + } + + if rerank_error_details is None: + rerank_error_details = { + "error": f"Engine {engine_name} is not compatible with current model or environment", + "error_type": "model_compatibility", + "technical_details": f"The {engine_name} engine cannot handle the current rerank model configuration", + } + + # For unavailable engines, return simple string format + engine_params[engine_name] = ( + rerank_error_details.get("error") or "Unknown error" + ) + + except Exception as e: + # If exception occurs during checking, return simple string format + engine_params[engine_name] = ( + f"Error checking engine {engine_name}: {str(e)}" + ) + + # Filter out rerank_class field + for engine in engine_params.keys(): + if isinstance( + engine_params[engine], list + ): # Only process parameter lists of available engines + for param in engine_params[engine]: # type: ignore + if isinstance(param, dict) and "rerank_class" in param: + del param["rerank_class"] return engine_params else: diff --git a/xinference/ui/web/ui/src/scenes/launch_model/components/launchModelDrawer.js b/xinference/ui/web/ui/src/scenes/launch_model/components/launchModelDrawer.js index 7a5bda45e8..173a0dc2a1 100644 --- a/xinference/ui/web/ui/src/scenes/launch_model/components/launchModelDrawer.js +++ b/xinference/ui/web/ui/src/scenes/launch_model/components/launchModelDrawer.js @@ -13,15 +13,11 @@ import { CircularProgress, Collapse, Drawer, - FormControl, FormControlLabel, - InputLabel, ListItemButton, ListItemText, - MenuItem, Radio, RadioGroup, - Select, Switch, TextField, Tooltip, @@ -39,45 +35,11 @@ import DynamicFieldList from './dynamicFieldList' import getModelFormConfig from './modelFormConfig' import PasteDialog from './pasteDialog' import Progress from './progress' +import SelectField from './selectField' const enginesWithNWorker = ['SGLang', 'vLLM', 'MLX'] const modelEngineType = ['LLM', 'embedding', 'rerank'] -const SelectField = ({ - label, - labelId, - name, - value, - onChange, - options = [], - disabled = false, - required = false, -}) => ( - - {label} - - -) - const LaunchModelDrawer = ({ modelData, modelType, @@ -549,19 +511,32 @@ const LaunchModelDrawer = ({ const engineItems = useMemo(() => { return engineOptions.map((engine) => { - const modelFormats = Array.from( - new Set(enginesObj[engine]?.map((item) => item.model_format)) - ) + const engineData = enginesObj[engine] + let modelFormats = [] + let label = engine + let disabled = false + + if (Array.isArray(engineData)) { + modelFormats = Array.from( + new Set(engineData.map((item) => item.model_format)) + ) - const relevantSpecs = modelData.model_specs.filter((spec) => - modelFormats.includes(spec.model_format) - ) + const relevantSpecs = modelData.model_specs.filter((spec) => + modelFormats.includes(spec.model_format) + ) + + const cached = relevantSpecs.some((spec) => isCached(spec)) - const cached = relevantSpecs.some((spec) => isCached(spec)) + label = cached ? `${engine} ${t('launchModel.cached')}` : engine + } else if (typeof engineData === 'string') { + label = `${engine} (${engineData})` + disabled = true + } return { value: engine, - label: cached ? `${engine} ${t('launchModel.cached')}` : engine, + label, + disabled, } }) }, [engineOptions, enginesObj, modelData]) diff --git a/xinference/ui/web/ui/src/scenes/launch_model/components/selectField.js b/xinference/ui/web/ui/src/scenes/launch_model/components/selectField.js new file mode 100644 index 0000000000..7e9a4af8ce --- /dev/null +++ b/xinference/ui/web/ui/src/scenes/launch_model/components/selectField.js @@ -0,0 +1,42 @@ +import { FormControl, InputLabel, MenuItem, Select } from '@mui/material' + +const SelectField = ({ + label, + labelId, + name, + value, + onChange, + options = [], + disabled = false, + required = false, +}) => ( + + {label} + + +) + +export default SelectField