|
6 | 6 | import httpx
|
7 | 7 | import urllib.parse
|
8 | 8 | from overrides import override
|
| 9 | +from tenacity import ( |
| 10 | + retry, |
| 11 | + stop_after_attempt, |
| 12 | + wait_exponential, |
| 13 | + retry_if_exception_type, |
| 14 | + before_sleep_log, |
| 15 | + RetryError |
| 16 | +) |
9 | 17 |
|
10 | 18 | from chromadb.api.collection_configuration import (
|
11 | 19 | CreateCollectionConfiguration,
|
|
56 | 64 |
|
57 | 65 | logger = logging.getLogger(__name__)
|
58 | 66 |
|
| 67 | +def is_retryable_exception(exception: BaseException) -> bool: |
| 68 | + if isinstance(exception, ( |
| 69 | + httpx.ConnectError, |
| 70 | + httpx.ConnectTimeout, |
| 71 | + httpx.ReadTimeout, |
| 72 | + httpx.WriteTimeout, |
| 73 | + httpx.PoolTimeout, |
| 74 | + httpx.NetworkError, |
| 75 | + httpx.RemoteProtocolError, |
| 76 | + )): |
| 77 | + return True |
| 78 | + |
| 79 | + if isinstance(exception, httpx.HTTPStatusError): |
| 80 | + # Retry on server errors that might be temporary |
| 81 | + return exception.response.status_code in [502, 503, 504] |
| 82 | + |
| 83 | + return False |
59 | 84 |
|
60 | 85 | class FastAPI(BaseHTTPClient, ServerAPI):
|
61 | 86 | def __init__(self, system: System):
|
@@ -97,20 +122,38 @@ def __init__(self, system: System):
|
97 | 122 | self._session.headers[header] = value.get_secret_value()
|
98 | 123 |
|
99 | 124 | def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
|
100 |
| - # If the request has json in kwargs, use orjson to serialize it, |
101 |
| - # remove it from kwargs, and add it to the content parameter |
102 |
| - # This is because httpx uses a slower json serializer |
103 |
| - if "json" in kwargs: |
104 |
| - data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY) |
105 |
| - kwargs["content"] = data |
106 |
| - |
107 |
| - # Unlike requests, httpx does not automatically escape the path |
108 |
| - escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None) |
109 |
| - url = self._api_url + escaped_path |
110 |
| - |
111 |
| - response = self._session.request(method, url, **cast(Any, kwargs)) |
112 |
| - BaseHTTPClient._raise_chroma_error(response) |
113 |
| - return orjson.loads(response.text) |
| 125 | + @retry( |
| 126 | + stop=stop_after_attempt(3), |
| 127 | + wait=wait_exponential( |
| 128 | + multiplier=2, |
| 129 | + min=1, |
| 130 | + max=60 |
| 131 | + ), |
| 132 | + retry=retry_if_exception_type(is_retryable_exception), |
| 133 | + before_sleep=before_sleep_log(logger, logging.INFO), |
| 134 | + reraise=True |
| 135 | + ) |
| 136 | + def _request_with_retry(): |
| 137 | + # If the request has json in kwargs, use orjson to serialize it, |
| 138 | + # remove it from kwargs, and add it to the content parameter |
| 139 | + # This is because httpx uses a slower json serializer |
| 140 | + if "json" in kwargs: |
| 141 | + data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY) |
| 142 | + kwargs["content"] = data |
| 143 | + |
| 144 | + # Unlike requests, httpx does not automatically escape the path |
| 145 | + escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None) |
| 146 | + url = self._api_url + escaped_path |
| 147 | + |
| 148 | + response = self._session.request(method, url, **cast(Any, kwargs)) |
| 149 | + BaseHTTPClient._raise_chroma_error(response) |
| 150 | + return orjson.loads(response.text) |
| 151 | + |
| 152 | + try: |
| 153 | + return _request_with_retry() |
| 154 | + except RetryError as e: |
| 155 | + # Re-raise the last exception that caused the retry to fail |
| 156 | + raise e.last_attempt.exception() from None |
114 | 157 |
|
115 | 158 | @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
|
116 | 159 | @override
|
|
0 commit comments