diff --git a/wsproxy/websocket_proxy.go b/wsproxy/websocket_proxy.go index 7092162..ebe884b 100644 --- a/wsproxy/websocket_proxy.go +++ b/wsproxy/websocket_proxy.go @@ -130,9 +130,12 @@ func defaultHeaderForwarder(header string) bool { // The cookie name is specified by the TokenCookieName value. // // example: -// Sec-Websocket-Protocol: Bearer, foobar +// +// Sec-Websocket-Protocol: Bearer, foobar +// // is converted to: -// Authorization: Bearer foobar +// +// Authorization: Bearer foobar // // Method can be overwritten with the MethodOverrideParam get parameter in the requested URL func WebsocketProxy(h http.Handler, opts ...Option) http.Handler { @@ -166,6 +169,9 @@ func isClosedConnError(err error) bool { func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { var responseHeader http.Header + var grpcMethodType string + var grpcGatewayBody string + // If Sec-WebSocket-Protocol starts with "Bearer", respond in kind. // TODO(tmc): consider customizability/extension point here. if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") { @@ -192,6 +198,8 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { if swsp := r.Header.Get("Sec-WebSocket-Protocol"); swsp != "" { request.Header.Set("Authorization", transformSubProtocolHeader(swsp)) } + grpcMethodType = r.Header.Get("x-grpc-method-type") + grpcGatewayBody = r.Header.Get("x-grpc-gateway-body") for header := range r.Header { if p.headerForwarder(header) { request.Header.Set(header, r.Header.Get(header)) @@ -233,6 +241,9 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { defer func() { cancelFn() }() + if grpcGatewayBody == "false" { + requestBodyW.Close() + } for { select { case <-ctx.Done(): @@ -259,6 +270,9 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { p.logger.Warnln("[read] error writing message to upstream http server:", err) return } + if grpcMethodType == "Unary" || grpcMethodType == "ServerStreaming" { + requestBodyW.Close() + } } }() // ping write loop @@ -303,6 +317,13 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { p.logger.Warnln("[write] error writing websocket message:", err) return } + if grpcMethodType == "Unary" || grpcMethodType == "ClientStreaming" { + // Close WebSocket + if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil { + p.logger.Warnln("[write] error writing websocket close message:", err) + return + } + } } if err := scanner.Err(); err != nil { p.logger.Warnln("scanner err:", err) @@ -338,12 +359,15 @@ func transformSubProtocolHeader(header string) string { func (w *inMemoryResponseWriter) Write(b []byte) (int, error) { return w.Writer.Write(b) } + func (w *inMemoryResponseWriter) Header() http.Header { return w.header } + func (w *inMemoryResponseWriter) WriteHeader(code int) { w.code = code } + func (w *inMemoryResponseWriter) CloseNotify() <-chan bool { return w.closed }