@@ -2,31 +2,97 @@ package httpmw
22
33import (
44 "fmt"
5+ "net"
56 "net/http"
7+ "strings"
8+ "sync"
69 "time"
710
811 "github.com/go-chi/httprate"
912
13+ "cdr.dev/slog"
1014 "github.com/coder/wgtunnel/tunneld/httpapi"
1115 "github.com/coder/wgtunnel/tunnelsdk"
1216)
1317
18+ type RateLimitConfig struct {
19+ Log slog.Logger
20+
21+ // Count of the amount of requests allowed in the Window. If the Count is
22+ // zero, the rate limiter is disabled.
23+ Count int
24+ Window time.Duration
25+
26+ // RealIPHeader is the header to use to get the real IP address of the
27+ // request. If this is empty, the request's RemoteAddr is used.
28+ RealIPHeader string
29+ }
30+
1431// RateLimit returns a handler that limits requests based on IP.
15- func RateLimit (count int , window time. Duration ) func (http.Handler ) http.Handler {
16- if count <= 0 {
32+ func RateLimit (cfg RateLimitConfig ) func (http.Handler ) http.Handler {
33+ if cfg . Count <= 0 {
1734 return func (handler http.Handler ) http.Handler {
1835 return handler
1936 }
2037 }
2138
39+ var logMissingHeaderOnce sync.Once
40+
2241 return httprate .Limit (
23- count ,
24- window ,
25- httprate .WithKeyByIP (),
42+ cfg .Count ,
43+ cfg .Window ,
44+ httprate .WithKeyFuncs (func (r * http.Request ) (string , error ) {
45+ if cfg .RealIPHeader != "" {
46+ val := r .Header .Get (cfg .RealIPHeader )
47+ if val != "" {
48+ val = strings .TrimSpace (strings .Split (val , "," )[0 ])
49+ return canonicalizeIP (val ), nil
50+ }
51+
52+ logMissingHeaderOnce .Do (func () {
53+ cfg .Log .Warn (r .Context (), "real IP header not found or invalid on request" , slog .F ("header" , cfg .RealIPHeader ), slog .F ("value" , val ))
54+ })
55+ }
56+
57+ return httprate .KeyByIP (r )
58+ }),
2659 httprate .WithLimitHandler (func (rw http.ResponseWriter , r * http.Request ) {
2760 httpapi .Write (r .Context (), rw , http .StatusTooManyRequests , tunnelsdk.Response {
28- Message : fmt .Sprintf ("You've been rate limited for sending more than %v requests in %v." , count , window ),
61+ Message : fmt .Sprintf ("You've been rate limited for sending more than %v requests in %v." , cfg . Count , cfg . Window ),
2962 })
3063 }),
3164 )
3265}
66+
67+ // canonicalizeIP returns a form of ip suitable for comparison to other IPs.
68+ // For IPv4 addresses, this is simply the whole string.
69+ // For IPv6 addresses, this is the /64 prefix.
70+ //
71+ // This function is taken directly from go-chi/httprate:
72+ // https://github.com/go-chi/httprate/blob/0ea2148d09a46ae62efcad05b70d87418d8e4f43/httprate.go#L111
73+ func canonicalizeIP (ip string ) string {
74+ isIPv6 := false
75+ // This is how net.ParseIP decides if an address is IPv6
76+ // https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/ip.go;l=704
77+ for i := 0 ; ! isIPv6 && i < len (ip ); i ++ {
78+ switch ip [i ] {
79+ case '.' :
80+ // IPv4
81+ return ip
82+ case ':' :
83+ // IPv6
84+ isIPv6 = true
85+ }
86+ }
87+ if ! isIPv6 {
88+ // Not an IP address at all
89+ return ip
90+ }
91+
92+ ipv6 := net .ParseIP (ip )
93+ if ipv6 == nil {
94+ return ip
95+ }
96+
97+ return ipv6 .Mask (net .CIDRMask (64 , 128 )).String ()
98+ }
0 commit comments