1
1
package apiserversdk
2
2
3
3
import (
4
+ "bytes"
5
+ "context"
6
+ "io"
4
7
"net/http"
5
8
"net/http/httputil"
6
9
"net/url"
7
10
"strings"
11
+ "time"
8
12
9
13
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
10
14
"k8s.io/apimachinery/pkg/util/net"
@@ -25,13 +29,16 @@ func NewMux(config MuxConfig) (*http.ServeMux, error) {
25
29
return nil , err
26
30
}
27
31
proxy := httputil .NewSingleHostReverseProxy (u )
28
- if proxy .Transport , err = rest .TransportFor (config .KubernetesConfig ); err != nil { // rest.TransportFor provides the auth to the K8s API server.
32
+ baseTransport , err := rest .TransportFor (config .KubernetesConfig )
33
+ if err != nil { // rest.TransportFor provides the auth to the K8s API server.
29
34
return nil , err
30
35
}
36
+ proxy .Transport = newRetryRoundTripper (baseTransport , 3 )
31
37
var handler http.Handler = proxy
32
38
if config .Middleware != nil {
33
39
handler = config .Middleware (proxy )
34
40
}
41
+ handler = bodyPreserveMiddleware (handler )
35
42
36
43
mux := http .NewServeMux ()
37
44
// TODO: add template features to specify routes.
@@ -84,3 +91,74 @@ func requireKubeRayService(handler http.Handler, k8sClient *kubernetes.Clientset
84
91
handler .ServeHTTP (w , r )
85
92
})
86
93
}
94
+
95
+ type retryRoundTripper struct {
96
+ base http.RoundTripper
97
+ retries int
98
+ }
99
+
100
+ func newRetryRoundTripper (base http.RoundTripper , retries int ) http.RoundTripper {
101
+ return & retryRoundTripper {base : base , retries : retries }
102
+ }
103
+
104
+ func (rrt * retryRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
105
+ timeoutCtx , cancel := context .WithTimeout (req .Context (), 30 * time .Second )
106
+ defer cancel ()
107
+
108
+ req = req .WithContext (timeoutCtx )
109
+
110
+ var resp * http.Response
111
+ var err error
112
+ for i := 0 ; i <= rrt .retries ; i ++ {
113
+ if i > 0 && req .GetBody != nil {
114
+ var bodyCopy io.ReadCloser
115
+ bodyCopy , err = req .GetBody ()
116
+ if err != nil {
117
+ return nil , err
118
+ }
119
+ req .Body = bodyCopy
120
+ }
121
+
122
+ resp , err = rrt .base .RoundTrip (req )
123
+ if err == nil {
124
+ return resp , nil
125
+ } else if ! shouldRetry (resp .StatusCode ) {
126
+ return resp , nil
127
+ }
128
+ if i < rrt .retries {
129
+ time .Sleep (time .Duration (1 << i ) * time .Second )
130
+ }
131
+ }
132
+ return resp , err
133
+ }
134
+
135
+ func shouldRetry (statusCode int ) bool {
136
+ switch statusCode {
137
+ case http .StatusInternalServerError ,
138
+ http .StatusBadGateway ,
139
+ http .StatusServiceUnavailable ,
140
+ http .StatusGatewayTimeout :
141
+ return true
142
+ default :
143
+ return false
144
+ }
145
+ }
146
+
147
+ func bodyPreserveMiddleware (h http.Handler ) http.Handler {
148
+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
149
+ if r .Body != nil && r .GetBody == nil {
150
+ bodyBytes , err := io .ReadAll (r .Body )
151
+ if err != nil {
152
+ http .Error (w , "failed to read request body" , http .StatusInternalServerError )
153
+ return
154
+ }
155
+ _ = r .Body .Close ()
156
+ r .Body = io .NopCloser (bytes .NewReader (bodyBytes ))
157
+ r .ContentLength = int64 (len (bodyBytes ))
158
+ r .GetBody = func () (io.ReadCloser , error ) {
159
+ return io .NopCloser (bytes .NewReader (bodyBytes )), nil
160
+ }
161
+ }
162
+ h .ServeHTTP (w , r )
163
+ })
164
+ }
0 commit comments