Skip to content

Commit ef8c70f

Browse files
sanketkediajairad26
authored andcommitted
[ENH]: Client side retries
1 parent b381e79 commit ef8c70f

File tree

1 file changed

+57
-14
lines changed

1 file changed

+57
-14
lines changed

chromadb/api/fastapi.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66
import httpx
77
import urllib.parse
88
from overrides import override
9+
from tenacity import (
10+
retry,
11+
stop_after_attempt,
12+
wait_exponential,
13+
retry_if_exception_type,
14+
before_sleep_log,
15+
RetryError
16+
)
917

1018
from chromadb.api.collection_configuration import (
1119
CreateCollectionConfiguration,
@@ -56,6 +64,23 @@
5664

5765
logger = logging.getLogger(__name__)
5866

67+
def is_retryable_exception(exception: BaseException) -> bool:
68+
if isinstance(exception, (
69+
httpx.ConnectError,
70+
httpx.ConnectTimeout,
71+
httpx.ReadTimeout,
72+
httpx.WriteTimeout,
73+
httpx.PoolTimeout,
74+
httpx.NetworkError,
75+
httpx.RemoteProtocolError,
76+
)):
77+
return True
78+
79+
if isinstance(exception, httpx.HTTPStatusError):
80+
# Retry on server errors that might be temporary
81+
return exception.response.status_code in [502, 503, 504]
82+
83+
return False
5984

6085
class FastAPI(BaseHTTPClient, ServerAPI):
6186
def __init__(self, system: System):
@@ -97,20 +122,38 @@ def __init__(self, system: System):
97122
self._session.headers[header] = value.get_secret_value()
98123

99124
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
100-
# If the request has json in kwargs, use orjson to serialize it,
101-
# remove it from kwargs, and add it to the content parameter
102-
# This is because httpx uses a slower json serializer
103-
if "json" in kwargs:
104-
data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
105-
kwargs["content"] = data
106-
107-
# Unlike requests, httpx does not automatically escape the path
108-
escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
109-
url = self._api_url + escaped_path
110-
111-
response = self._session.request(method, url, **cast(Any, kwargs))
112-
BaseHTTPClient._raise_chroma_error(response)
113-
return orjson.loads(response.text)
125+
@retry(
126+
stop=stop_after_attempt(3),
127+
wait=wait_exponential(
128+
multiplier=2,
129+
min=1,
130+
max=60
131+
),
132+
retry=retry_if_exception_type(is_retryable_exception),
133+
before_sleep=before_sleep_log(logger, logging.INFO),
134+
reraise=True
135+
)
136+
def _request_with_retry():
137+
# If the request has json in kwargs, use orjson to serialize it,
138+
# remove it from kwargs, and add it to the content parameter
139+
# This is because httpx uses a slower json serializer
140+
if "json" in kwargs:
141+
data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
142+
kwargs["content"] = data
143+
144+
# Unlike requests, httpx does not automatically escape the path
145+
escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
146+
url = self._api_url + escaped_path
147+
148+
response = self._session.request(method, url, **cast(Any, kwargs))
149+
BaseHTTPClient._raise_chroma_error(response)
150+
return orjson.loads(response.text)
151+
152+
try:
153+
return _request_with_retry()
154+
except RetryError as e:
155+
# Re-raise the last exception that caused the retry to fail
156+
raise e.last_attempt.exception() from None
114157

115158
@trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
116159
@override

0 commit comments

Comments
 (0)