Skip to content

Commit a104474

Browse files
committed
make time window generic
1 parent aae16f3 commit a104474

File tree

5 files changed

+77
-53
lines changed

5 files changed

+77
-53
lines changed

pkg/gofr/service/rate_limiter.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ var (
1717

1818
// RateLimiterConfig with custom keying support.
1919
type RateLimiterConfig struct {
20-
RequestsPerSecond float64 // Token refill rate (must be > 0)
21-
Burst int // Maximum burst capacity (must be > 0)
22-
KeyFunc func(*http.Request) string // Optional custom key extraction
23-
RedisClient *gofrRedis.Redis `json:"-"` // Optional Redis for distributed limiting
20+
Requests float64 // Number of requests allowed
21+
Window time.Duration // Time window (e.g., time.Minute, time.Hour)
22+
Burst int // Maximum burst capacity (must be > 0)
23+
KeyFunc func(*http.Request) string // Optional custom key extraction
24+
RedisClient *gofrRedis.Redis `json:"-"` // Optional Redis for distributed limiting
2425
}
2526

2627
// defaultKeyFunc extracts a normalized service key from an HTTP request.
@@ -53,8 +54,12 @@ func defaultKeyFunc(req *http.Request) string {
5354

5455
// Validate checks if the configuration is valid.
5556
func (config *RateLimiterConfig) Validate() error {
56-
if config.RequestsPerSecond <= 0 {
57-
return fmt.Errorf("%w: %f", errInvalidRequestRate, config.RequestsPerSecond)
57+
if config.Requests <= 0 {
58+
return fmt.Errorf("%w: %f", errInvalidRequestRate, config.Requests)
59+
}
60+
61+
if config.Window <= 0 {
62+
config.Window = time.Minute // Default: per-minute rate limiting
5863
}
5964

6065
if config.Burst <= 0 {
@@ -92,6 +97,12 @@ func (config *RateLimiterConfig) AddOption(h HTTP) HTTP {
9297
return NewLocalRateLimiter(*config, h)
9398
}
9499

100+
// RequestsPerSecond converts the configured rate to requests per second.
101+
func (config *RateLimiterConfig) RequestsPerSecond() float64 {
102+
// Convert any time window to "requests per second" for internal math
103+
return float64(config.Requests) / config.Window.Seconds()
104+
}
105+
95106
// RateLimitError represents a rate limiting error.
96107
type RateLimitError struct {
97108
ServiceKey string

pkg/gofr/service/rate_limiter_distributed.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,21 @@ import (
1616
const tokenBucketScript = `
1717
local key = KEYS[1]
1818
local burst = tonumber(ARGV[1])
19-
local refill_rate = tonumber(ARGV[2])
20-
local now = tonumber(ARGV[3])
19+
local requests = tonumber(ARGV[2])
20+
local window_seconds = tonumber(ARGV[3])
21+
local now = tonumber(ARGV[4])
22+
23+
-- Calculate refill rate as requests per second
24+
local refill_rate = requests / window_seconds
2125
2226
-- Fetch bucket
2327
local bucket = redis.call("HMGET", key, "tokens", "last_refill")
2428
local tokens = tonumber(bucket[1])
2529
local last_refill = tonumber(bucket[2])
2630
2731
if tokens == nil then
28-
tokens = burst
29-
last_refill = now
32+
tokens = burst
33+
last_refill = now
3034
end
3135
3236
-- Refill tokens
@@ -37,10 +41,10 @@ local allowed = 0
3741
local retryAfter = 0
3842
3943
if new_tokens >= 1 then
40-
allowed = 1
41-
new_tokens = new_tokens - 1
44+
allowed = 1
45+
new_tokens = new_tokens - 1
4246
else
43-
retryAfter = math.ceil((1 - new_tokens) / refill_rate * 1000) -- ms
47+
retryAfter = math.ceil((1 - new_tokens) / refill_rate * 1000) -- ms
4448
end
4549
4650
redis.call("HSET", key, "tokens", new_tokens, "last_refill", now)
@@ -108,7 +112,7 @@ func (rl *distributedRateLimiter) checkRateLimit(req *http.Request) error {
108112
tokenBucketScript,
109113
[]string{"gofr:ratelimit:" + serviceKey},
110114
rl.config.Burst,
111-
rl.config.RequestsPerSecond,
115+
int64(rl.config.Window.Seconds()),
112116
now,
113117
)
114118

pkg/gofr/service/rate_limiter_local.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,11 @@ func NewLocalRateLimiter(config RateLimiterConfig, h HTTP) HTTP {
7373
}
7474

7575
// newTokenBucket creates a new atomic token bucket with proper float64 scaling.
76-
func newTokenBucket(maxTokens int, refillRate float64) *tokenBucket {
77-
maxScaled := int64(maxTokens) * scale
76+
func newTokenBucket(config *RateLimiterConfig) *tokenBucket {
77+
maxScaled := int64(config.Burst) * scale
7878

79-
refillPerNanoFloat := refillRate * float64(scale) / float64(time.Second)
79+
requestsPerSecond := config.RequestsPerSecond()
80+
refillPerNanoFloat := requestsPerSecond * float64(scale) / float64(time.Second)
8081

8182
return &tokenBucket{
8283
tokens: maxScaled,
@@ -209,7 +210,7 @@ func (rl *localRateLimiter) checkRateLimit(req *http.Request) error {
209210
now := time.Now().Unix()
210211

211212
entry, _ := rl.buckets.LoadOrStore(serviceKey, &bucketEntry{
212-
bucket: newTokenBucket(rl.config.Burst, rl.config.RequestsPerSecond),
213+
bucket: newTokenBucket(&rl.config),
213214
lastAccess: now,
214215
})
215216

pkg/gofr/service/rate_limiter_local_test.go

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ func TestNewLocalRateLimiter_Basic(t *testing.T) {
6464
base := newBaseHTTPService(t, &hits)
6565

6666
rl := NewLocalRateLimiter(RateLimiterConfig{
67-
RequestsPerSecond: 5,
68-
Burst: 5,
69-
KeyFunc: func(*http.Request) string { return "svc-basic" },
67+
Requests: 5,
68+
Window: time.Second,
69+
Burst: 5,
70+
KeyFunc: func(*http.Request) string { return "svc-basic" },
7071
}, base)
7172

7273
resp, err := rl.Get(t.Context(), "/ok", nil)
@@ -86,9 +87,10 @@ func TestLocalRateLimiter_EnforceLimit(t *testing.T) {
8687
base := newBaseHTTPService(t, &hits)
8788

8889
rl := NewLocalRateLimiter(RateLimiterConfig{
89-
RequestsPerSecond: 1,
90-
Burst: 1,
91-
KeyFunc: func(*http.Request) string { return "svc-limit" },
90+
Requests: 1,
91+
Window: time.Second,
92+
Burst: 1,
93+
KeyFunc: func(*http.Request) string { return "svc-limit" },
9294
}, base)
9395

9496
resp, err := rl.Get(t.Context(), "/r1", nil)
@@ -125,9 +127,10 @@ func TestLocalRateLimiter_FractionalRPS(t *testing.T) {
125127
base := newBaseHTTPService(t, &hits)
126128

127129
rl := NewLocalRateLimiter(RateLimiterConfig{
128-
RequestsPerSecond: 0.5,
129-
Burst: 1,
130-
KeyFunc: func(*http.Request) string { return "svc-frac" },
130+
Requests: 0.5,
131+
Window: time.Second,
132+
Burst: 1,
133+
KeyFunc: func(*http.Request) string { return "svc-frac" },
131134
}, base)
132135

133136
resp, err := rl.Get(t.Context(), "/a", nil)
@@ -164,9 +167,10 @@ func TestLocalRateLimiter_CustomKey_SharedBucket(t *testing.T) {
164167
base := newBaseHTTPService(t, &hits)
165168

166169
rl := NewLocalRateLimiter(RateLimiterConfig{
167-
RequestsPerSecond: 1,
168-
Burst: 1,
169-
KeyFunc: func(*http.Request) string { return "shared-key" },
170+
Requests: 1,
171+
Window: time.Second,
172+
Burst: 1,
173+
KeyFunc: func(*http.Request) string { return "shared-key" },
170174
}, base)
171175

172176
resp, err := rl.Get(t.Context(), "/p1", nil)
@@ -206,9 +210,10 @@ func TestLocalRateLimiter_Concurrency(t *testing.T) {
206210
base := newBaseHTTPService(t, &hits)
207211

208212
rl := NewLocalRateLimiter(RateLimiterConfig{
209-
RequestsPerSecond: 1,
210-
Burst: 1,
211-
KeyFunc: func(*http.Request) string { return "svc-conc" },
213+
Requests: 1,
214+
Window: time.Second,
215+
Burst: 1,
216+
KeyFunc: func(*http.Request) string { return "svc-conc" },
212217
}, base)
213218

214219
const workers = 12
@@ -275,9 +280,10 @@ func TestLocalRateLimiter_NoMetrics(t *testing.T) {
275280
base := newBaseHTTPService(t, &hits)
276281

277282
rl := NewLocalRateLimiter(RateLimiterConfig{
278-
RequestsPerSecond: 2,
279-
Burst: 2,
280-
KeyFunc: func(*http.Request) string { return "svc-nometrics" },
283+
Requests: 2,
284+
Window: time.Second,
285+
Burst: 2,
286+
KeyFunc: func(*http.Request) string { return "svc-nometrics" },
281287
}, base)
282288

283289
resp, err := rl.Get(t.Context(), "/m", nil)
@@ -295,9 +301,10 @@ func TestLocalRateLimiter_RateLimitErrorFields(t *testing.T) {
295301
base := newBaseHTTPService(t, &hits)
296302

297303
rl := NewLocalRateLimiter(RateLimiterConfig{
298-
RequestsPerSecond: 0, // Always zero refill
299-
Burst: 1,
300-
KeyFunc: func(*http.Request) string { return "svc-zero" },
304+
Requests: 0, // Always zero refill
305+
Window: time.Second,
306+
Burst: 1,
307+
KeyFunc: func(*http.Request) string { return "svc-zero" },
301308
}, base)
302309

303310
resp, err := rl.Get(t.Context(), "/z1", nil)
@@ -330,16 +337,17 @@ func TestLocalRateLimiter_WrapperMethods_SuccessAndLimited(t *testing.T) {
330337

331338
// Success limiter: plenty of capacity
332339
successRL := NewLocalRateLimiter(RateLimiterConfig{
333-
RequestsPerSecond: 100,
334-
Burst: 100,
335-
KeyFunc: func(*http.Request) string { return "wrapper-allow" },
340+
Requests: 100,
341+
Window: time.Second,
342+
Burst: 100,
343+
KeyFunc: func(*http.Request) string { return "wrapper-allow" },
336344
}, base)
337345

338346
// Deny limiter: zero capacity (covers error branch)
339347
denyRL := NewLocalRateLimiter(RateLimiterConfig{
340-
RequestsPerSecond: 0,
341-
Burst: 0,
342-
KeyFunc: func(*http.Request) string { return "wrapper-deny" },
348+
Requests: 0,
349+
Burst: 0,
350+
KeyFunc: func(*http.Request) string { return "wrapper-deny" },
343351
}, base)
344352

345353
tests := []struct {

pkg/gofr/service/rate_limiter_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,21 @@ func newHTTPService(t *testing.T) *httpService {
3434

3535
func TestRateLimiterConfig_Validate(t *testing.T) {
3636
t.Run("invalid RPS", func(t *testing.T) {
37-
cfg := RateLimiterConfig{RequestsPerSecond: 0, Burst: 1}
37+
cfg := RateLimiterConfig{Requests: 0, Burst: 1}
3838
err := cfg.Validate()
3939
require.Error(t, err)
4040
assert.ErrorIs(t, err, errInvalidRequestRate)
4141
})
4242

4343
t.Run("invalid Burst", func(t *testing.T) {
44-
cfg := RateLimiterConfig{RequestsPerSecond: 1, Burst: 0}
44+
cfg := RateLimiterConfig{Requests: 1, Burst: 0}
4545
err := cfg.Validate()
4646
require.Error(t, err)
4747
assert.ErrorIs(t, err, errInvalidBurstSize)
4848
})
4949

5050
t.Run("sets default KeyFunc when nil", func(t *testing.T) {
51-
cfg := RateLimiterConfig{RequestsPerSecond: 1.5, Burst: 2}
51+
cfg := RateLimiterConfig{Requests: 1.5, Burst: 2}
5252
require.Nil(t, cfg.KeyFunc)
5353
require.NoError(t, cfg.Validate())
5454
require.NotNil(t, cfg.KeyFunc)
@@ -98,14 +98,14 @@ func TestDefaultKeyFunc(t *testing.T) {
9898

9999
func TestAddOption_InvalidConfigReturnsOriginal(t *testing.T) {
100100
h := newHTTPService(t)
101-
cfg := RateLimiterConfig{RequestsPerSecond: 0, Burst: 1} // invalid
101+
cfg := RateLimiterConfig{Requests: 0, Burst: 1} // invalid
102102
out := cfg.AddOption(h)
103103
assert.Same(t, h, out)
104104
}
105105

106106
func TestAddOption_LocalLimiter(t *testing.T) {
107107
h := newHTTPService(t)
108-
cfg := RateLimiterConfig{RequestsPerSecond: 2, Burst: 3}
108+
cfg := RateLimiterConfig{Requests: 2, Burst: 3}
109109
out := cfg.AddOption(h)
110110

111111
_, isLocal := out.(*localRateLimiter)
@@ -117,9 +117,9 @@ func TestAddOption_LocalLimiter(t *testing.T) {
117117
func TestAddOption_DistributedLimiter(t *testing.T) {
118118
h := newHTTPService(t)
119119
cfg := RateLimiterConfig{
120-
RequestsPerSecond: 5,
121-
Burst: 5,
122-
RedisClient: new(gofrRedis.Redis),
120+
Requests: 5,
121+
Burst: 5,
122+
RedisClient: new(gofrRedis.Redis),
123123
}
124124

125125
out := cfg.AddOption(h)

0 commit comments

Comments
 (0)