diff --git a/ipv6.go b/ipv6.go new file mode 100644 index 0000000000..1898c54afe --- /dev/null +++ b/ipv6.go @@ -0,0 +1,195 @@ +package fasthttp + +import ( + "bytes" + "errors" +) + +var ( + errInvalidIPv6Host = errors.New("invalid IPv6 host") + errInvalidIPv6Zone = errors.New("invalid IPv6 zone") + errInvalidIPv6Address = errors.New("invalid IPv6 address") +) + +func validateIPv6Literal(host []byte) error { + if len(host) == 0 || host[0] != '[' { + return nil + } + end := bytes.IndexByte(host, ']') + if end < 0 || end == 1 { + return errInvalidIPv6Host + } + addr := host[1:end] + + // Optional zone. + if zi := bytes.IndexByte(addr, '%'); zi >= 0 { + if zi == len(addr)-1 { + return errInvalidIPv6Zone + } + addr = addr[:zi] + } + + // Must have a colon to be IPv6. + if bytes.IndexByte(addr, ':') < 0 { + return errInvalidIPv6Address + } + + // IPv4-embedded? + if bytes.IndexByte(addr, '.') >= 0 { + lastColon := bytes.LastIndexByte(addr, ':') + if lastColon < 0 || lastColon == len(addr)-1 { + return errInvalidIPv6Address + } + + ipv4 := addr[lastColon+1:] + if !validIPv4(ipv4) { + return errInvalidIPv6Address + } + + head := addr[:lastColon] + seenDoubleAtSplit := lastColon > 0 && addr[lastColon-1] == ':' + if seenDoubleAtSplit { + head = addr[:lastColon-1] + } + + hextets, seenDoubleHead, ok := parseIPv6Hextets(head, false) + if !ok { + return errInvalidIPv6Address + } + + if seenDoubleHead && seenDoubleAtSplit { + return errInvalidIPv6Address + } + + hextets += 2 // IPv4 tail = 2 hextets + seenDouble := seenDoubleHead || seenDoubleAtSplit + + // '::' must compress at least one hextet. + if (!seenDouble && hextets != 8) || (seenDouble && hextets >= 8) { + return errInvalidIPv6Address + } + return nil + } + + // Pure IPv6 + hextets, seenDouble, ok := parseIPv6Hextets(addr, false) + if !ok { + return errInvalidIPv6Address + } + if (!seenDouble && hextets != 8) || (seenDouble && hextets >= 8) { + return errInvalidIPv6Address + } + return nil +} + +func parseIPv6Hextets(s []byte, allowTrailingColon bool) (groups int, seenDouble, ok bool) { + n := len(s) + if n == 0 { + return 0, false, true + } + i := 0 + justSawDouble := false + + for i < n { + if s[i] == ':' { + if i+1 < n && s[i+1] == ':' { + if seenDouble || justSawDouble { + return 0, false, false + } + seenDouble = true + justSawDouble = true + i += 2 + if i == n { + break + } + continue + } + if i == 0 { + return 0, false, false + } + if justSawDouble { + return 0, false, false + } + if i == n-1 { + if allowTrailingColon { + break + } + return 0, false, false + } + if !ishex(s[i+1]) { + return 0, false, false + } + i++ + continue + } + + justSawDouble = false + cnt := 0 + for cnt < 4 && i < n && ishex(s[i]) { + i++ + cnt++ + } + if cnt == 0 { + return 0, false, false + } + groups++ + + if i < n && s[i] != ':' { + return 0, false, false + } + } + return groups, seenDouble, true +} + +// validIPv4 validates a dotted-quad (exactly 4 parts, 0..255) with no leading zeros +// unless the octet is exactly "0". +func validIPv4(s []byte) bool { + parts := 0 + i := 0 + n := len(s) + + for parts < 4 { + if i >= n { + return false + } + + start := i + val := 0 + digits := 0 + + for i < n { + c := s[i] + if c < '0' || c > '9' { + break + } + val = val*10 + int(c-'0') + if val > 255 { + return false + } + i++ + digits++ + if digits > 3 { + return false + } + } + if digits == 0 { + return false + } + + // Disallow leading zeros like "00", "01", "001". + // Allowed: exactly "0" or any number that doesn't start with '0'. + if digits > 1 && s[start] == '0' { + return false + } + + parts++ + if parts == 4 { + return i == n // must consume all input + } + if i >= n || s[i] != '.' { + return false + } + i++ // skip dot + } + return false +} diff --git a/ipv6_test.go b/ipv6_test.go new file mode 100644 index 0000000000..c2641d2dba --- /dev/null +++ b/ipv6_test.go @@ -0,0 +1,85 @@ +package fasthttp + +import ( + "bytes" + "net" + "testing" +) + +// oracleValid replicates the original function's semantics using net.ParseIP: +// - Input must start with '[' +// - There must be a closing ']' and a non-empty address between +// - Optional %zone allowed but must not be empty +// - Zone is stripped before checking with net.ParseIP +// - Must contain a ':' to be IPv6 (prevents raw IPv4-in-brackets). +func oracleValid(host []byte) bool { + if len(host) == 0 || host[0] != '[' { + // Original function: non-bracketed hosts return nil (treated as valid/no-op). + return true + } + + end := bytes.IndexByte(host, ']') + if end < 0 { + return false + } + addr := host[1:end] + if len(addr) == 0 { + return false + } + + // Split off %zone (if present). + if zi := bytes.IndexByte(addr, '%'); zi >= 0 { + // Zone must not be empty. + if zi == len(addr)-1 { + return false + } + addr = addr[:zi] + } + + // Must contain ':' to be IPv6. + if bytes.IndexByte(addr, ':') < 0 { + return false + } + + // Use net.ParseIP on the de-zoned address (this was the original check). + if ip := net.ParseIP(string(addr)); ip == nil { + return false + } + return true +} + +func FuzzValidateIPv6Literal(f *testing.F) { + seeds := [][]byte{ + []byte(""), // non-bracketed => valid (no-op) + []byte("example.com"), // non-bracketed => valid (no-op) + []byte("["), // unterminated + []byte("[]"), // empty + []byte("[::]"), + []byte("[::1]"), + []byte("[2001:db8::1]"), + []byte("[2001:db8::]"), + []byte("[::ffff:192.168.0.1]"), + []byte("[fe80::1%eth0]"), + []byte("[fe80::1%]"), // empty zone + []byte("[1234]"), // no colon + []byte("[2001:db8:zzzz::1]"), // invalid hex + []byte("[::ffff:256.0.0.1]"), // invalid v4 tail + []byte("[2001:db8:::1]"), // triple colon + []byte("[::1]:443"), // trailing port outside ']' is ignored by validator + []byte("[2001:db8:0:0:0:0:2:1]"), + []byte("[2001:db8:0:0:0:0:2:1%en0]"), + } + for _, s := range seeds { + f.Add(s) + } + + f.Fuzz(func(t *testing.T, host []byte) { + gotErr := validateIPv6Literal(host) + wantValid := oracleValid(host) + + if (gotErr == nil) != wantValid { + t.Fatalf("mismatch for %q: validateIPv6Literal err=%v, oracleValid=%v", + b2s(host), gotErr, wantValid) + } + }) +} diff --git a/uri.go b/uri.go index b55fda0a96..a9627dc0d8 100644 --- a/uri.go +++ b/uri.go @@ -408,6 +408,9 @@ func parseHost(host []byte) ([]byte, error) { if host, err = unescape(host, encodeHost); err != nil { return nil, err } + if err = validateIPv6Literal(host); err != nil { + return nil, err + } return host, nil } diff --git a/uri_test.go b/uri_test.go index 8ee40f9ff8..2d6a6424e0 100644 --- a/uri_test.go +++ b/uri_test.go @@ -104,6 +104,33 @@ func testURIPathEscape(t *testing.T, path, expectedRequestURI string) { } } +func TestURIRejectInvalidIPv6(t *testing.T) { + t.Parallel() + + for _, raw := range []string{ + "http://[0:0::vulndetector.com]:80", + "http://[2001:db8::vulndetector.com]/", + "http://[vulndetector.com]/", + "http://[::ffff:192.0.2.300]/", + } { + var u URI + if err := u.Parse(nil, []byte(raw)); err == nil { + t.Errorf("expected Parse to fail for %q", raw) + } + } + + for _, raw := range []string{ + "http://[2001:db8::1]/", + "http://[fe80::1%25en0]/", + "http://[::ffff:192.0.2.1]/", + } { + var u URI + if err := u.Parse(nil, []byte(raw)); err != nil { + t.Errorf("unexpected error for %q: %v", raw, err) + } + } +} + func TestURIUpdate(t *testing.T) { t.Parallel()