diff --git a/http.go b/http.go index b4d3cbdf79..8aa92a91a9 100644 --- a/http.go +++ b/http.go @@ -66,6 +66,7 @@ type Request struct { // Group bool members in order to reduce Request object size. parsedURI bool parsedPostArgs bool + uriParseErr error keepBodyBuffer bool @@ -146,12 +147,14 @@ func (req *Request) Host() []byte { func (req *Request) SetRequestURI(requestURI string) { req.Header.SetRequestURI(requestURI) req.parsedURI = false + req.uriParseErr = nil } // SetRequestURIBytes sets RequestURI. func (req *Request) SetRequestURIBytes(requestURI []byte) { req.Header.SetRequestURIBytes(requestURI) req.parsedURI = false + req.uriParseErr = nil } // RequestURI returns request's URI. @@ -891,6 +894,7 @@ func (req *Request) copyToSkipBody(dst *Request) { req.uri.CopyTo(&dst.uri) dst.parsedURI = req.parsedURI + dst.uriParseErr = req.uriParseErr req.postArgs.CopyTo(&dst.postArgs) dst.parsedPostArgs = req.parsedPostArgs @@ -960,19 +964,22 @@ func (req *Request) SetURI(newURI *URI) { if newURI != nil { newURI.CopyTo(&req.uri) req.parsedURI = true + req.uriParseErr = nil return } req.uri.Reset() req.parsedURI = false + req.uriParseErr = nil } func (req *Request) parseURI() error { if req.parsedURI { - return nil + return req.uriParseErr } - req.parsedURI = true - return req.uri.parse(req.Header.Host(), req.Header.RequestURI(), req.isTLS) + req.parsedURI = true + req.uriParseErr = req.uri.parse(req.Header.Host(), req.Header.RequestURI(), req.isTLS) + return req.uriParseErr } // PostArgs returns POST arguments. @@ -1146,6 +1153,7 @@ func (req *Request) resetSkipHeader() { req.ResetBody() req.uri.Reset() req.parsedURI = false + req.uriParseErr = nil req.postArgs.Reset() req.parsedPostArgs = false req.isTLS = false diff --git a/http_test.go b/http_test.go index c8e01ea60f..3fa84261d9 100644 --- a/http_test.go +++ b/http_test.go @@ -155,7 +155,8 @@ func testRequestCopyTo(t *testing.T, src *Request) { var dst Request src.CopyTo(&dst) - if !reflect.DeepEqual(src, &dst) { + // Compare serialized representations. + if src.String() != dst.String() || !bytes.Equal(src.Body(), dst.Body()) { t.Fatalf("RequestCopyTo fail, src: \n%+v\ndst: \n%+v\n", src, &dst) } } diff --git a/server.go b/server.go index f27d8de5da..0598e825ad 100644 --- a/server.go +++ b/server.go @@ -2349,11 +2349,21 @@ func (s *Server) serveConn(c net.Conn) error { writeTimeout = s.WriteTimeout } } - // read body - if s.StreamRequestBody { - err = ctx.Request.readBodyStream(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) - } else { - err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) + + if err == nil { + if err = ctx.Request.parseURI(); err != nil { + bw = s.writeErrorResponse(bw, ctx, serverName, err) + break + } + } + + if err == nil { + // read body + if s.StreamRequestBody { + err = ctx.Request.readBodyStream(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) + } else { + err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) + } } } // When StreamRequestBody is set to true, we cannot safely release br. diff --git a/server_test.go b/server_test.go index 03ce609d36..254e7f857c 100644 --- a/server_test.go +++ b/server_test.go @@ -16,6 +16,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "testing" "time" @@ -1882,6 +1883,38 @@ func TestServerHeadRequest(t *testing.T) { } } +func TestServerRejectsBackslashInAbsoluteURI(t *testing.T) { + t.Parallel() + + var handlerCalled atomic.Bool + s := &Server{ + Handler: func(ctx *RequestCtx) { + handlerCalled.Store(true) + ctx.Success("text/plain", []byte("ok")) + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET http://vulndetector.com\\\\admin\\\\api HTTP/1.1\r\nHost: example.com\r\n\r\n") + + _ = s.ServeConn(rw) + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "Error when parsing request") + + if handlerCalled.Load() { + t.Fatal("handler should not run for invalid absolute URI") + } + + data, err := io.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %v", err) + } + if len(data) > 0 { + t.Fatalf("unexpected remaining data %q", data) + } +} + func TestServerExpect100Continue(t *testing.T) { t.Parallel()