Skip to content

Commit c7cf891

Browse files
committed
fix: retry times error && add comments for clarity
Signed-off-by: Cheng-Yeh Chung <[email protected]>
1 parent bacb44b commit c7cf891

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

apiserversdk/proxy.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func NewMux(config MuxConfig) (*http.ServeMux, error) {
3434
if err != nil {
3535
return nil, fmt.Errorf("failed to get transport for config: %w", err)
3636
}
37-
proxy.Transport = newRetryRoundTripper(baseTransport, HTTPClientDefaultMaxRetry)
37+
proxy.Transport = newRetryRoundTripper(baseTransport)
3838
var handler http.Handler = proxy
3939
if config.Middleware != nil {
4040
handler = config.Middleware(proxy)
@@ -95,21 +95,26 @@ func requireKubeRayService(handler http.Handler, k8sClient *kubernetes.Clientset
9595
// retryRoundTripper is a custom implementation of http.RoundTripper that retries HTTP requests.
9696
// It verifies retryable HTTP status codes and retries using exponential backoff.
9797
type retryRoundTripper struct {
98-
base http.RoundTripper
99-
retries int
98+
base http.RoundTripper
99+
100+
// Num of retries after the initial attempt
101+
maxRetries int
100102
}
101103

102-
func newRetryRoundTripper(base http.RoundTripper, retries int) http.RoundTripper {
103-
return &retryRoundTripper{base: base, retries: retries}
104+
func newRetryRoundTripper(base http.RoundTripper) http.RoundTripper {
105+
return &retryRoundTripper{base: base, maxRetries: HTTPClientDefaultMaxRetry}
104106
}
105107

106108
func (rrt *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
107109
ctx := req.Context()
108110

109111
var resp *http.Response
110112
var err error
111-
for attempt := 0; attempt < rrt.retries; attempt++ {
113+
for attempt := 0; attempt <= rrt.maxRetries; attempt++ {
114+
/* Try up to (rrt.maxRetries + 1) times: initial attempt + retries */
115+
112116
if attempt == 0 && req.Body != nil && req.GetBody == nil {
117+
/* Reuse request body in each attempt */
113118
bodyBytes, err := io.ReadAll(req.Body)
114119
if err != nil {
115120
return nil, fmt.Errorf("failed to read request body for retry support: %w", err)
@@ -146,7 +151,8 @@ func (rrt *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
146151
return resp, nil
147152
}
148153

149-
if attempt < rrt.retries-1 && resp.Body != nil {
154+
if attempt < rrt.maxRetries && resp.Body != nil {
155+
/* If not last attempt, drain response body */
150156
if _, err = io.Copy(io.Discard, resp.Body); err != nil {
151157
return nil, fmt.Errorf("retryRoundTripper internal failure to drain response body: %w", err)
152158
}

apiserversdk/proxy_test.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,6 @@ var _ = Describe("kuberay service", Ordered, func() {
329329
})
330330

331331
var _ = Describe("retryRoundTripper", func() {
332-
const maxAttemps = 5
333-
334332
It("should not retry on successful status OK", func() {
335333
var attempts int32
336334
mock := &mockRoundTripper{
@@ -342,7 +340,7 @@ var _ = Describe("retryRoundTripper", func() {
342340
}, nil
343341
},
344342
}
345-
retrier := newRetryRoundTripper(mock, maxAttemps /*retries*/)
343+
retrier := newRetryRoundTripper(mock)
346344
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
347345
Expect(err).ToNot(HaveOccurred())
348346
resp, err := retrier.RoundTrip(req)
@@ -352,12 +350,12 @@ var _ = Describe("retryRoundTripper", func() {
352350
})
353351

354352
It("should retry failed requests and eventually succeed", func() {
355-
const maxFailure = 3
353+
const maxFailure = 2
356354
var attempts int32
357355
mock := &mockRoundTripper{
358356
fn: func(_ *http.Request) (*http.Response, error) {
359357
count := atomic.AddInt32(&attempts, 1)
360-
if count < maxFailure {
358+
if count <= maxFailure {
361359
return &http.Response{
362360
StatusCode: http.StatusInternalServerError,
363361
Body: io.NopCloser(strings.NewReader("internal error")),
@@ -369,13 +367,13 @@ var _ = Describe("retryRoundTripper", func() {
369367
}, nil
370368
},
371369
}
372-
retrier := newRetryRoundTripper(mock, maxAttemps /*retries*/)
370+
retrier := newRetryRoundTripper(mock)
373371
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
374372
Expect(err).ToNot(HaveOccurred())
375373
resp, err := retrier.RoundTrip(req)
376374
Expect(err).ToNot(HaveOccurred())
377375
Expect(resp.StatusCode).To(Equal(http.StatusOK))
378-
Expect(attempts).To(Equal(int32(maxFailure)))
376+
Expect(attempts).To(Equal(int32(maxFailure + 1)))
379377
})
380378

381379
It("Retries exceed maximum retry counts", func() {
@@ -389,13 +387,13 @@ var _ = Describe("retryRoundTripper", func() {
389387
}, nil
390388
},
391389
}
392-
retrier := newRetryRoundTripper(mock, maxAttemps /*retries*/)
390+
retrier := newRetryRoundTripper(mock)
393391
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
394392
Expect(err).ToNot(HaveOccurred())
395393
resp, err := retrier.RoundTrip(req)
396394
Expect(err).ToNot(HaveOccurred())
397395
Expect(resp.StatusCode).To(Equal(http.StatusInternalServerError))
398-
Expect(attempts).To(Equal(int32(maxAttemps)))
396+
Expect(attempts).To(Equal(int32(HTTPClientDefaultMaxRetry + 1)))
399397
})
400398

401399
It("should not retry on non-retriable status", func() {
@@ -409,7 +407,7 @@ var _ = Describe("retryRoundTripper", func() {
409407
}, nil
410408
},
411409
}
412-
retrier := newRetryRoundTripper(mock, maxAttemps /*retries*/)
410+
retrier := newRetryRoundTripper(mock)
413411
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
414412
Expect(err).ToNot(HaveOccurred())
415413
resp, err := retrier.RoundTrip(req)
@@ -428,7 +426,7 @@ var _ = Describe("retryRoundTripper", func() {
428426
}, nil
429427
},
430428
}
431-
retrier := newRetryRoundTripper(mock, maxAttemps /*retries*/)
429+
retrier := newRetryRoundTripper(mock)
432430
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
433431
defer cancel()
434432
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://test", nil)

0 commit comments

Comments
 (0)