@@ -1,11 +1,15 @@
 import importlib.util
+import json
+import logging
 import uuid
 from typing import List, Optional, Union
 
 from ....types import Document, DocumentObj, Meta, Rerank, RerankTokens
 from ...utils import cache_clean
 from ..core import RerankModel, RerankModelFamilyV2, RerankSpecV1
 
+logger = logging.getLogger(__name__)
+
 SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "qwen3"]
 
 
@@ -67,6 +71,42 @@ def load(self):
                     classifier_from_token=["no", "yes"],
                     is_original_qwen3_reranker=True,
                 )
+            elif isinstance(self._kwargs["hf_overrides"], str):
+                self._kwargs["hf_overrides"] = json.loads(self._kwargs["hf_overrides"])
+                self._kwargs["hf_overrides"].update(
+                    architectures=["Qwen3ForSequenceClassification"],
+                    classifier_from_token=["no", "yes"],
+                    is_original_qwen3_reranker=True,
+                )
+
+        # Set appropriate VLLM configuration parameters based on model capabilities
+        model_max_tokens = getattr(self.model_family, "max_tokens", 512)
+
+        # Set max_model_len based on model family capabilities with reasonable limits
+        max_model_len = min(model_max_tokens, 8192)
+        if "max_model_len" not in self._kwargs:
+            self._kwargs["max_model_len"] = max_model_len
+
+        # Ensure max_num_batched_tokens is sufficient for large models
+        if "max_num_batched_tokens" not in self._kwargs:
+            # max_num_batched_tokens should be at least max_model_len
+            # Set to a reasonable minimum that satisfies the constraint
+            self._kwargs["max_num_batched_tokens"] = max(4096, max_model_len)
+
+        # Configure other reasonable defaults for reranking models
+        if "gpu_memory_utilization" not in self._kwargs:
+            self._kwargs["gpu_memory_utilization"] = 0.7
+
+        # Use a smaller block size for better compatibility
+        if "block_size" not in self._kwargs:
+            self._kwargs["block_size"] = 16
+
+        logger.debug(
+            f"VLLM configuration for rerank model {self.model_family.model_name}: "
+            f"max_model_len={self._kwargs.get('max_model_len')}, "
+            f"max_num_batched_tokens={self._kwargs.get('max_num_batched_tokens')}"
+        )
+
         self._model = LLM(model=self._model_path, task="score", **self._kwargs)
         self._tokenizer = self._model.get_tokenizer()
 
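Note: the defaulting rules introduced above can be sanity-checked without a GPU or a vLLM install. The snippet below is a minimal standalone sketch, not part of the patch: the helper name `apply_vllm_rerank_defaults` and its plain-dict signature are made up for illustration, and the Qwen3-specific `hf_overrides` keys are left out. The values mirror the logic added to `load()`.

```python
import json


def apply_vllm_rerank_defaults(kwargs: dict, model_max_tokens: int = 512) -> dict:
    """Illustrative mirror of the kwargs defaulting added in load()."""
    # Accept hf_overrides passed as a JSON string by decoding it into a dict
    # (the Qwen3-specific keys merged in the patch are omitted here).
    if isinstance(kwargs.get("hf_overrides"), str):
        kwargs["hf_overrides"] = json.loads(kwargs["hf_overrides"])

    # Default max_model_len to the model's max_tokens, capped at 8192;
    # an explicit user value always wins.
    max_model_len = min(model_max_tokens, 8192)
    kwargs.setdefault("max_model_len", max_model_len)

    # Keep max_num_batched_tokens at least max_model_len, with a floor of 4096.
    kwargs.setdefault("max_num_batched_tokens", max(4096, max_model_len))

    # Conservative memory and block-size defaults for reranking workloads.
    kwargs.setdefault("gpu_memory_utilization", 0.7)
    kwargs.setdefault("block_size", 16)
    return kwargs


if __name__ == "__main__":
    # hf_overrides supplied as a JSON string, no length limits set by the user.
    cfg = apply_vllm_rerank_defaults(
        {"hf_overrides": '{"is_original_qwen3_reranker": true}'},
        model_max_tokens=32768,
    )
    print(cfg)
    # {'hf_overrides': {'is_original_qwen3_reranker': True}, 'max_model_len': 8192,
    #  'max_num_batched_tokens': 8192, 'gpu_memory_utilization': 0.7, 'block_size': 16}
```

Because every rule is an "only if absent" check, user-supplied engine arguments always take precedence; the defaults merely fill gaps.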