Skip to content

Commit 7690f2c

Browse files
committed
[ENH]: Client side retries
1 parent b614a69 commit 7690f2c

File tree

1 file changed

+57
-14
lines changed

1 file changed

+57
-14
lines changed

chromadb/api/fastapi.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66
import httpx
77
import urllib.parse
88
from overrides import override
9+
from tenacity import (
10+
retry,
11+
stop_after_attempt,
12+
wait_exponential,
13+
retry_if_exception_type,
14+
before_sleep_log,
15+
RetryError
16+
)
917

1018
from chromadb.api.collection_configuration import (
1119
CreateCollectionConfiguration,
@@ -58,6 +66,23 @@
5866

5967
logger = logging.getLogger(__name__)
6068

69+
def is_retryable_exception(exception: BaseException) -> bool:
    """Report whether a failed request should be attempted again.

    Args:
        exception: The exception raised by a request attempt.

    Returns:
        True when ``exception`` is one of the listed httpx connection,
        timeout, network or protocol errors, or when it is an
        ``httpx.HTTPStatusError`` whose response status is 502, 503 or 504.
        False for everything else.
    """
    transient_transport_errors = (
        httpx.ConnectError,
        httpx.ConnectTimeout,
        httpx.ReadTimeout,
        httpx.WriteTimeout,
        httpx.PoolTimeout,
        httpx.NetworkError,
        httpx.RemoteProtocolError,
    )
    # Gateway / availability statuses that might be temporary.
    transient_status_codes = (502, 503, 504)

    if isinstance(exception, transient_transport_errors):
        return True

    if not isinstance(exception, httpx.HTTPStatusError):
        return False

    return exception.response.status_code in transient_status_codes
6186

6287
class FastAPI(BaseHTTPClient, ServerAPI):
6388
def __init__(self, system: System):
@@ -99,20 +124,38 @@ def __init__(self, system: System):
99124
self._session.headers[header] = value.get_secret_value()
100125

101126
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
    """Send a request to the server, retrying transient failures.

    The request is attempted at most 3 times. Between attempts the client
    sleeps with exponential backoff bounded to the 1-60 second range, and
    logs each upcoming retry at INFO level. Only exceptions for which
    ``is_retryable_exception`` returns True trigger another attempt.

    Args:
        method: HTTP method name passed to the session.
        path: Path appended to ``self._api_url`` after percent-escaping.
        **kwargs: Extra arguments forwarded to ``self._session.request``.
            A ``json`` entry is serialised with orjson and sent as
            ``content`` instead.

    Returns:
        The response body decoded with ``orjson.loads``.

    Raises:
        Whatever the final attempt raised (``reraise=True``); non-retryable
        exceptions propagate immediately after the first attempt.
    """
    # Imported locally: the module-level tenacity import list does not
    # include this name. `retry_if_exception` takes a predicate, whereas
    # `retry_if_exception_type` takes exception classes -- handing it a
    # function makes its internal isinstance() check raise TypeError on
    # every failed attempt, hiding the real error.
    from tenacity import retry_if_exception

    # If the request has json in kwargs, use orjson to serialize it,
    # remove it from kwargs, and add it to the content parameter.
    # This is because httpx uses a slower json serializer.
    # Done once, outside the retried closure: the payload does not change
    # between attempts.
    if "json" in kwargs:
        kwargs["content"] = orjson.dumps(kwargs.pop("json"))

    # Unlike requests, httpx does not automatically escape the path
    escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
    url = self._api_url + escaped_path

    # NOTE(review): timeouts are retried for every HTTP method; confirm that
    # repeating a non-idempotent request is acceptable for the server.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=2, min=1, max=60),
        retry=retry_if_exception(is_retryable_exception),
        before_sleep=before_sleep_log(logger, logging.INFO),
        # Surface the last attempt's own exception instead of RetryError,
        # so no RetryError unwrapping is needed at the call site below.
        reraise=True,
    )
    def _request_with_retry() -> Any:
        response = self._session.request(method, url, **cast(Any, kwargs))
        BaseHTTPClient._raise_chroma_error(response)
        return orjson.loads(response.text)

    return _request_with_retry()
116159

117160
@trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
118161
@override

0 commit comments

Comments
 (0)