Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f76b6b9
feat: add retry and timeout to apiserver V2
kenchung285 Jul 17, 2025
430f429
feat: add logs to failure retries
kenchung285 Jul 18, 2025
d0cab52
ensure timeout duration and retryable status codes consistent with ap…
kenchung285 Jul 18, 2025
043d99b
refactor comments
kenchung285 Jul 20, 2025
4807020
merge body preserving into retryRoundTripper
kenchung285 Jul 23, 2025
5949d16
add todo comment for v1 v2 compatibility
kenchung285 Jul 23, 2025
501975d
feat: drain response every time before retry
kenchung285 Jul 23, 2025
440730e
add TODO comment for merging common utils in v1 and v2
kenchung285 Jul 23, 2025
a10c8f5
test: add mock test for retry round tripper
kenchung285 Jul 23, 2025
08da214
Add comment for retryRoundTripper mock unit test
kenchung285 Jul 24, 2025
37e885b
feat: improve error log
kenchung285 Jul 24, 2025
73d7592
test: add more test for retryRoundTripper && fix minor errors
kenchung285 Jul 24, 2025
b7d776f
minor: format
kenchung285 Jul 24, 2025
5cffede
improve error message
kenchung285 Jul 24, 2025
5c1b1dd
renaming function
kenchung285 Jul 24, 2025
3c60df2
add todo comment for http util funtions
kenchung285 Jul 24, 2025
31129fc
fix: retry times error && add comments for clarity
kenchung285 Jul 25, 2025
501a64f
return timeout error for remaining time less than the backoff sleep d…
kenchung285 Jul 26, 2025
b783d41
refactor && use select to prevent blocking if the request context bei…
kenchung285 Jul 27, 2025
d2a3b72
refactor: avoid using req.GetBody in reusing request body
kenchung285 Jul 27, 2025
2d73386
update
kenchung285 Jul 27, 2025
2a0e301
Update apiserversdk/proxy.go
kenchung285 Jul 27, 2025
515ccdc
test: add a test for request with body
kenchung285 Jul 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions apiserversdk/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package apiserversdk

import "time"

// TODO: Make apiserver configs compatible with V1
const (
// Max retry times for HTTP Client
HTTPClientDefaultMaxRetry = 3

// Retry backoff settings
HTTPClientDefaultBackoffBase = float64(2)
HTTPClientDefaultInitBackoff = 500 * time.Millisecond
HTTPClientDefaultMaxBackoff = 10 * time.Second

// Overall timeout for retries
HTTPClientDefaultOverallTimeout = 30 * time.Second
)
122 changes: 119 additions & 3 deletions apiserversdk/proxy.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package apiserversdk

import (
"bytes"
"fmt"
"io"
"math"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/net"
Expand All @@ -22,12 +27,14 @@ type MuxConfig struct {
func NewMux(config MuxConfig) (*http.ServeMux, error) {
u, err := url.Parse(config.KubernetesConfig.Host) // parse the K8s API server URL from the KubernetesConfig.
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse url %s from config: %w", config.KubernetesConfig.Host, err)
}
proxy := httputil.NewSingleHostReverseProxy(u)
if proxy.Transport, err = rest.TransportFor(config.KubernetesConfig); err != nil { // rest.TransportFor provides the auth to the K8s API server.
return nil, err
baseTransport, err := rest.TransportFor(config.KubernetesConfig) // rest.TransportFor provides the auth to the K8s API server.
if err != nil {
return nil, fmt.Errorf("failed to get transport for config: %w", err)
}
proxy.Transport = newRetryRoundTripper(baseTransport)
var handler http.Handler = proxy
if config.Middleware != nil {
handler = config.Middleware(proxy)
Expand Down Expand Up @@ -84,3 +91,112 @@ func requireKubeRayService(handler http.Handler, k8sClient *kubernetes.Clientset
handler.ServeHTTP(w, r)
})
}

// retryRoundTripper is a custom implementation of http.RoundTripper that retries HTTP requests.
// It verifies retryable HTTP status codes and retries using exponential backoff.
type retryRoundTripper struct {
base http.RoundTripper

// Num of retries after the initial attempt
maxRetries int
}

func newRetryRoundTripper(base http.RoundTripper) http.RoundTripper {
return &retryRoundTripper{base: base, maxRetries: HTTPClientDefaultMaxRetry}
}

func (rrt *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()

var bodyBytes []byte
var resp *http.Response
var err error

if req.Body != nil {
/* Reuse request body in each attempt */
bodyBytes, err = io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body for retry support: %w", err)
}
err = req.Body.Close()
if err != nil {
return nil, fmt.Errorf("failed to close request body: %w", err)
}
}

for attempt := 0; attempt <= rrt.maxRetries; attempt++ {
/* Try up to (rrt.maxRetries + 1) times: initial attempt + retries */

if bodyBytes != nil {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}

resp, err = rrt.base.RoundTrip(req)
if err != nil {
return resp, fmt.Errorf("request to %s %s failed with error: %w", req.Method, req.URL.String(), err)
}

if isSuccessfulStatusCode(resp.StatusCode) {
return resp, nil
}

if !isRetryableHTTPStatusCodes(resp.StatusCode) {
return resp, nil
}

if attempt == rrt.maxRetries {
return resp, nil
}

if resp.Body != nil {
/* If not last attempt, drain response body */
if _, err = io.Copy(io.Discard, resp.Body); err != nil {
return nil, fmt.Errorf("retryRoundTripper internal failure to drain response body: %w", err)
}
if err = resp.Body.Close(); err != nil {
return nil, fmt.Errorf("retryRoundTripper internal failure to close response body: %w", err)
}
}

// TODO: move to HTTP util function in independent util file
sleepDuration := HTTPClientDefaultInitBackoff * time.Duration(math.Pow(HTTPClientDefaultBackoffBase, float64(attempt)))
if sleepDuration > HTTPClientDefaultMaxBackoff {
sleepDuration = HTTPClientDefaultMaxBackoff
}

// TODO: merge common utils for apiserver v1 and v2
if deadline, ok := ctx.Deadline(); ok {
remaining := time.Until(deadline)
if sleepDuration > remaining {
return resp, fmt.Errorf("retry timeout exceeded context deadline")
}
}

select {
case <-time.After(sleepDuration):
case <-ctx.Done():
return resp, fmt.Errorf("retry canceled during backoff: %w", ctx.Err())
}
}
return resp, err
}

// TODO: move HTTP util function into independent util file / folder
func isSuccessfulStatusCode(statusCode int) bool {
return 200 <= statusCode && statusCode < 300
}

// TODO: merge common utils for apiserver v1 and v2
func isRetryableHTTPStatusCodes(statusCode int) bool {
switch statusCode {
case http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout: // 504
return true
default:
return false
}
}
154 changes: 154 additions & 0 deletions apiserversdk/proxy_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package apiserversdk

import (
"bytes"
"context"
"errors"
"io"
"net"
"net/http"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -325,3 +328,154 @@ var _ = Describe("kuberay service", Ordered, func() {
})
})
})

var _ = Describe("retryRoundTripper", func() {
It("should not retry on successful status OK", func() {
var attempts int32
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&attempts, 1)
return &http.Response{ /* Always return OK status */
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("OK")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
Expect(attempts).To(Equal(int32(1)))
})

It("should retry failed requests and eventually succeed", func() {
const maxFailure = 2
var attempts int32
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
count := atomic.AddInt32(&attempts, 1)
if count <= maxFailure {
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader("internal error")),
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("ok")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
Expect(attempts).To(Equal(int32(maxFailure + 1)))
})

It("Retries exceed maximum retry counts", func() {
var attempts int32
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&attempts, 1)
return &http.Response{ /* Always return retriable status */
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader("internal error")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusInternalServerError))
Expect(attempts).To(Equal(int32(HTTPClientDefaultMaxRetry + 1)))
})

It("Retries on request with body", func() {
const testBody = "test-body"
const maxFailure = 2
var attempts int32
mock := &mockRoundTripper{
fn: func(req *http.Request) (*http.Response, error) {
count := atomic.AddInt32(&attempts, 1)
reqBody, err := io.ReadAll(req.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(reqBody)).To(Equal(testBody))

if count <= maxFailure {
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader("internal error")),
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("ok")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
body := bytes.NewBufferString(testBody)
req, err := http.NewRequest(http.MethodPost, "http://test", body)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
Expect(attempts).To(Equal(int32(maxFailure + 1)))
})

It("should not retry on non-retriable status", func() {
var attempts int32
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&attempts, 1)
return &http.Response{ /* Always return non-retriable status */
StatusCode: http.StatusNotFound,
Body: io.NopCloser(strings.NewReader("Not Found")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusNotFound))
Expect(attempts).To(Equal(int32(1)))
})

It("should respect context timeout and stop retrying", func() {
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
time.Sleep(100 * time.Millisecond)
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader("internal error")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("retry timeout exceeded context deadline"))
Expect(resp).ToNot(BeNil())
})
})

type mockRoundTripper struct {
fn func(*http.Request) (*http.Response, error)
}

func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return m.fn(req)
}
Loading