|
6 | 6 | import httpx
|
7 | 7 | import urllib.parse
|
8 | 8 | from overrides import override
|
| 9 | +from tenacity import ( |
| 10 | + RetryError, |
| 11 | + Retrying, |
| 12 | + before_sleep_log, |
| 13 | + retry_if_exception, |
| 14 | + stop_after_attempt, |
| 15 | + wait_exponential, |
| 16 | + wait_random_exponential, |
| 17 | +) |
9 | 18 |
|
10 | 19 | from chromadb.api.collection_configuration import (
|
11 | 20 | CreateCollectionConfiguration,
|
|
57 | 66 | logger = logging.getLogger(__name__)
|
58 | 67 |
|
59 | 68 |
|
def is_retryable_exception(exception: BaseException) -> bool:
    """Decide whether a failed request attempt should be retried.

    The listed httpx connection, timeout, network and remote-protocol
    errors are always retryable. An ``httpx.HTTPStatusError`` is retryable
    only when its response status code is 502, 503 or 504. Any other
    exception is not retryable.

    Args:
        exception: The exception raised by a request attempt.

    Returns:
        True if the attempt should be retried, False otherwise.
    """
    transient_transport_errors = (
        httpx.ConnectError,
        httpx.ConnectTimeout,
        httpx.ReadTimeout,
        httpx.WriteTimeout,
        httpx.PoolTimeout,
        httpx.NetworkError,
        httpx.RemoteProtocolError,
    )
    if isinstance(exception, transient_transport_errors):
        return True

    if not isinstance(exception, httpx.HTTPStatusError):
        return False

    # Retry on server errors that might be temporary
    return exception.response.status_code in (502, 503, 504)
| 89 | + |
| 90 | + |
60 | 91 | class FastAPI(BaseHTTPClient, ServerAPI):
|
61 | 92 | def __init__(self, system: System):
|
62 | 93 | super().__init__(system)
|
@@ -97,20 +128,62 @@ def __init__(self, system: System):
|
97 | 128 | self._session.headers[header] = value.get_secret_value()
|
98 | 129 |
|
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
    """Send an HTTP request to the API and return the decoded JSON body.

    The request is sent through ``self._session``. When
    ``self._settings.retry_config`` is set, attempts that fail with an
    exception accepted by ``is_retryable_exception`` are retried with
    exponential backoff (optionally randomized) up to ``max_attempts``
    attempts; otherwise the request is sent exactly once.

    Args:
        method: HTTP method name passed to the session.
        path: Path appended to ``self._api_url``; it is percent-escaped
            here (``/`` is kept as-is).
        **kwargs: Extra keyword arguments forwarded to the session's
            ``request`` call. A ``json`` entry is serialized with orjson
            and sent as ``content`` instead.

    Returns:
        The response text parsed with ``orjson.loads``.

    Raises:
        Whatever ``BaseHTTPClient._raise_chroma_error`` raises for an
        unsuccessful response, or the exception of the last attempt once
        retries are exhausted.
    """
    # If the request has json in kwargs, use orjson to serialize it,
    # remove it from kwargs, and add it to the content parameter.
    # This is because httpx uses a slower json serializer.
    # Done once, up front, so retries reuse the same payload instead of
    # re-running invariant work inside every attempt.
    if "json" in kwargs:
        data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
        kwargs["content"] = data

    # Unlike requests, httpx does not automatically escape the path
    escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
    url = self._api_url + escaped_path

    def _send_request() -> Any:
        # A single attempt: send, translate error responses, decode JSON.
        response = self._session.request(method, url, **cast(Any, kwargs))
        BaseHTTPClient._raise_chroma_error(response)
        return orjson.loads(response.text)

    retry_config = self._settings.retry_config
    if retry_config is None:
        return _send_request()

    # Sanitize the configured delays: never negative, and max >= min.
    min_delay = max(float(retry_config.min_delay), 0.0)
    max_delay = max(float(retry_config.max_delay), min_delay)
    # Keep the multiplier strictly positive so the backoff can grow even
    # when min_delay is 0.
    multiplier = max(min_delay, 1e-3)
    # Fall back to doubling when the configured factor is not positive.
    exp_base = retry_config.factor if retry_config.factor > 0 else 2.0

    wait_args = {
        "multiplier": multiplier,
        "min": min_delay,
        "max": max_delay,
        "exp_base": exp_base,
    }

    wait_strategy = (
        wait_random_exponential(**wait_args)
        if retry_config.jitter
        else wait_exponential(**wait_args)
    )

    # NOTE(review): retries are not restricted by HTTP method, so a
    # non-idempotent request may be re-sent after e.g. a read timeout.
    # Confirm this is intended.
    retrying = Retrying(
        stop=stop_after_attempt(retry_config.max_attempts),
        wait=wait_strategy,
        retry=retry_if_exception(is_retryable_exception),
        before_sleep=before_sleep_log(logger, logging.INFO),
        # reraise=True makes tenacity raise the last attempt's own
        # exception once attempts are exhausted, so callers never see a
        # RetryError and no unwrapping handler is needed here.
        reraise=True,
    )

    return retrying(_send_request)
114 | 187 |
|
115 | 188 | @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
|
116 | 189 | @override
|
|
0 commit comments