Skip to content

Commit e2a73d9

Browse files
committed
feat: add retry and timeout to apiserver V2
Signed-off-by: Cheng-Yeh Chung <[email protected]>
1 parent a658405 commit e2a73d9

File tree

1 file changed

+79
-1
lines changed

1 file changed

+79
-1
lines changed

apiserversdk/proxy.go

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package apiserversdk
22

33
import (
4+
"bytes"
5+
"context"
6+
"io"
47
"net/http"
58
"net/http/httputil"
69
"net/url"
710
"strings"
11+
"time"
812

913
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1014
"k8s.io/apimachinery/pkg/util/net"
@@ -25,13 +29,16 @@ func NewMux(config MuxConfig) (*http.ServeMux, error) {
2529
return nil, err
2630
}
2731
proxy := httputil.NewSingleHostReverseProxy(u)
28-
if proxy.Transport, err = rest.TransportFor(config.KubernetesConfig); err != nil { // rest.TransportFor provides the auth to the K8s API server.
32+
baseTransport, err := rest.TransportFor(config.KubernetesConfig)
33+
if err != nil { // rest.TransportFor provides the auth to the K8s API server.
2934
return nil, err
3035
}
36+
proxy.Transport = newRetryRoundTripper(baseTransport, 3)
3137
var handler http.Handler = proxy
3238
if config.Middleware != nil {
3339
handler = config.Middleware(proxy)
3440
}
41+
handler = bodyPreserveMiddleware(handler)
3542

3643
mux := http.NewServeMux()
3744
// TODO: add template features to specify routes.
@@ -84,3 +91,74 @@ func requireKubeRayService(handler http.Handler, k8sClient *kubernetes.Clientset
8491
handler.ServeHTTP(w, r)
8592
})
8693
}
94+
95+
type retryRoundTripper struct {
96+
base http.RoundTripper
97+
retries int
98+
}
99+
100+
func newRetryRoundTripper(base http.RoundTripper, retries int) http.RoundTripper {
101+
return &retryRoundTripper{base: base, retries: retries}
102+
}
103+
104+
func (rrt *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
105+
timeoutCtx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
106+
defer cancel()
107+
108+
req = req.WithContext(timeoutCtx)
109+
110+
var resp *http.Response
111+
var err error
112+
for i := 0; i <= rrt.retries; i++ {
113+
if i > 0 && req.GetBody != nil {
114+
var bodyCopy io.ReadCloser
115+
bodyCopy, err = req.GetBody()
116+
if err != nil {
117+
return nil, err
118+
}
119+
req.Body = bodyCopy
120+
}
121+
122+
resp, err = rrt.base.RoundTrip(req)
123+
if err == nil {
124+
return resp, nil
125+
} else if !shouldRetry(resp.StatusCode) {
126+
return resp, nil
127+
}
128+
if i < rrt.retries {
129+
time.Sleep(time.Duration(1<<i) * time.Second)
130+
}
131+
}
132+
return resp, err
133+
}
134+
135+
func shouldRetry(statusCode int) bool {
136+
switch statusCode {
137+
case http.StatusInternalServerError,
138+
http.StatusBadGateway,
139+
http.StatusServiceUnavailable,
140+
http.StatusGatewayTimeout:
141+
return true
142+
default:
143+
return false
144+
}
145+
}
146+
147+
func bodyPreserveMiddleware(h http.Handler) http.Handler {
148+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
149+
if r.Body != nil && r.GetBody == nil {
150+
bodyBytes, err := io.ReadAll(r.Body)
151+
if err != nil {
152+
http.Error(w, "failed to read request body", http.StatusInternalServerError)
153+
return
154+
}
155+
_ = r.Body.Close()
156+
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
157+
r.ContentLength = int64(len(bodyBytes))
158+
r.GetBody = func() (io.ReadCloser, error) {
159+
return io.NopCloser(bytes.NewReader(bodyBytes)), nil
160+
}
161+
}
162+
h.ServeHTTP(w, r)
163+
})
164+
}

0 commit comments

Comments
 (0)