Skip to content

Commit fcd1624

Browse files
committed
fix(guard): replace internal caches with moka
1 parent 37074b9 commit fcd1624

File tree

2 files changed

+73
-71
lines changed

2 files changed

+73
-71
lines changed

packages/edge/infra/guard/core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ prometheus = "0.13.3"
2727
rivet-config.workspace = true
2828
rand = "0.8.5"
2929
cluster.workspace = true
30-
scc = "2.0.7"
30+
moka = { version = "0.12", features = ["future"] }
3131
pegboard.workspace = true
3232
regex = "1.10.3"
3333
futures-util = "0.3.30"

packages/edge/infra/guard/core/src/proxy_service.rs

Lines changed: 72 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@ use std::{
88
time::{Duration, Instant},
99
};
1010

11+
use tokio::sync::Mutex;
12+
1113
use bytes::Bytes;
1214
use futures_util::{SinkExt, StreamExt};
1315
use global_error::*;
@@ -18,8 +20,8 @@ use hyper::{Request, Response, StatusCode};
1820
use hyper_tungstenite;
1921
use hyper_util::client::legacy::Client;
2022
use hyper_util::rt::TokioExecutor;
23+
use moka::future::Cache;
2124
use rand;
22-
use scc::HashMap as SccHashMap;
2325
use serde_json;
2426
use tokio::time::timeout;
2527
use tracing::Instrument;
@@ -29,6 +31,8 @@ use uuid::Uuid;
2931
use crate::metrics;
3032

3133
const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
34+
const ROUTE_CACHE_TTL: Duration = Duration::from_secs(60 * 10); // 10 minutes
35+
const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour
3236

3337
// Routing types
3438
#[derive(Clone, Debug)]
@@ -154,38 +158,40 @@ pub type MiddlewareFn = Arc<
154158

155159
// Cache for routing results
156160
struct RouteCache {
157-
cache: SccHashMap<(String, String), RouteConfig>,
161+
cache: Cache<(String, String), RouteConfig>,
158162
}
159163

160164
impl RouteCache {
161165
fn new() -> Self {
162166
Self {
163-
cache: SccHashMap::new(),
167+
cache: Cache::builder()
168+
.max_capacity(10_000)
169+
.time_to_live(ROUTE_CACHE_TTL)
170+
.build(),
164171
}
165172
}
166173

167174
#[tracing::instrument(skip_all)]
168175
async fn get(&self, hostname: &str, path: &str) -> Option<RouteConfig> {
169176
self.cache
170-
.get_async(&(hostname.to_owned(), path.to_owned()))
177+
.get(&(hostname.to_owned(), path.to_owned()))
171178
.await
172-
.map(|v| v.clone())
173179
}
174180

175181
#[tracing::instrument(skip_all)]
176182
async fn insert(&self, hostname: String, path: String, result: RouteConfig) {
177-
self.cache.upsert_async((hostname, path), result).await;
183+
self.cache.insert((hostname, path), result).await;
178184

179-
metrics::ROUTE_CACHE_SIZE.set(self.cache.len() as i64);
185+
metrics::ROUTE_CACHE_SIZE.set(self.cache.entry_count() as i64);
180186
}
181187

182188
#[tracing::instrument(skip_all)]
183189
async fn purge(&self, hostname: &str, path: &str) {
184190
self.cache
185-
.remove_async(&(hostname.to_owned(), path.to_owned()))
191+
.invalidate(&(hostname.to_owned(), path.to_owned()))
186192
.await;
187193

188-
metrics::ROUTE_CACHE_SIZE.set(self.cache.len() as i64);
194+
metrics::ROUTE_CACHE_SIZE.set(self.cache.entry_count() as i64);
189195
}
190196
}
191197

@@ -257,8 +263,8 @@ pub struct ProxyState {
257263
routing_fn: RoutingFn,
258264
middleware_fn: MiddlewareFn,
259265
route_cache: RouteCache,
260-
rate_limiters: SccHashMap<(Uuid, std::net::IpAddr), RateLimiter>,
261-
in_flight_counters: SccHashMap<(Uuid, std::net::IpAddr), InFlightCounter>,
266+
rate_limiters: Cache<(Uuid, std::net::IpAddr), Arc<Mutex<RateLimiter>>>,
267+
in_flight_counters: Cache<(Uuid, std::net::IpAddr), Arc<Mutex<InFlightCounter>>>,
262268
port_type: PortType,
263269
}
264270

@@ -274,8 +280,14 @@ impl ProxyState {
274280
routing_fn,
275281
middleware_fn,
276282
route_cache: RouteCache::new(),
277-
rate_limiters: SccHashMap::new(),
278-
in_flight_counters: SccHashMap::new(),
283+
rate_limiters: Cache::builder()
284+
.max_capacity(10_000)
285+
.time_to_live(PROXY_STATE_CACHE_TTL)
286+
.build(),
287+
in_flight_counters: Cache::builder()
288+
.max_capacity(10_000)
289+
.time_to_live(PROXY_STATE_CACHE_TTL)
290+
.build(),
279291
port_type,
280292
}
281293
}
@@ -465,28 +477,29 @@ impl ProxyState {
465477
let middleware_config = self.get_middleware_config(&actor_id).await?;
466478

467479
let cache_key = (actor_id, ip_addr);
468-
let entry = self
469-
.rate_limiters
470-
.entry_async(cache_key)
471-
.instrument(tracing::info_span!("entry_async"))
472-
.await;
473-
if let scc::hash_map::Entry::Occupied(mut entry) = entry {
474-
// Key exists, get and mutate existing RateLimiter
475-
let write_guard = entry.get_mut();
476-
Ok(write_guard.try_acquire())
480+
481+
// Get existing limiter or create a new one
482+
let limiter_arc = if let Some(existing_limiter) = self.rate_limiters.get(&cache_key).await {
483+
existing_limiter
477484
} else {
478-
// Key doesn't exist, insert a new RateLimiter
479-
let mut limiter = RateLimiter::new(
485+
let new_limiter = Arc::new(Mutex::new(RateLimiter::new(
480486
middleware_config.rate_limit.requests,
481487
middleware_config.rate_limit.period,
482-
);
483-
let result = limiter.try_acquire();
484-
entry.insert_entry(limiter);
488+
)));
489+
self.rate_limiters
490+
.insert(cache_key, new_limiter.clone())
491+
.await;
492+
metrics::RATE_LIMITER_COUNT.set(self.rate_limiters.entry_count() as i64);
493+
new_limiter
494+
};
485495

486-
metrics::RATE_LIMITER_COUNT.set(self.rate_limiters.len() as i64);
496+
// Try to acquire from the limiter
497+
let result = {
498+
let mut limiter = limiter_arc.lock().await;
499+
limiter.try_acquire()
500+
};
487501

488-
Ok(result)
489-
}
502+
Ok(result)
490503
}
491504

492505
#[tracing::instrument(skip_all)]
@@ -504,25 +517,29 @@ impl ProxyState {
504517
let middleware_config = self.get_middleware_config(&actor_id).await?;
505518

506519
let cache_key = (actor_id, ip_addr);
507-
let entry = self
508-
.in_flight_counters
509-
.entry_async(cache_key)
510-
.instrument(tracing::info_span!("entry_async"))
511-
.await;
512-
if let scc::hash_map::Entry::Occupied(mut entry) = entry {
513-
// Key exists, get and mutate existing InFlightCounter
514-
let write_guard = entry.get_mut();
515-
Ok(write_guard.try_acquire())
516-
} else {
517-
// Key doesn't exist, insert a new InFlightCounter
518-
let mut counter = InFlightCounter::new(middleware_config.max_in_flight.amount);
519-
let result = counter.try_acquire();
520-
entry.insert_entry(counter);
521520

522-
metrics::IN_FLIGHT_COUNTER_COUNT.set(self.in_flight_counters.len() as i64);
521+
// Get existing counter or create a new one
522+
let counter_arc =
523+
if let Some(existing_counter) = self.in_flight_counters.get(&cache_key).await {
524+
existing_counter
525+
} else {
526+
let new_counter = Arc::new(Mutex::new(InFlightCounter::new(
527+
middleware_config.max_in_flight.amount,
528+
)));
529+
self.in_flight_counters
530+
.insert(cache_key, new_counter.clone())
531+
.await;
532+
metrics::IN_FLIGHT_COUNTER_COUNT.set(self.in_flight_counters.entry_count() as i64);
533+
new_counter
534+
};
535+
536+
// Try to acquire from the counter
537+
let result = {
538+
let mut counter = counter_arc.lock().await;
539+
counter.try_acquire()
540+
};
523541

524-
Ok(result)
525-
}
542+
Ok(result)
526543
}
527544

528545
#[tracing::instrument(skip_all)]
@@ -534,12 +551,8 @@ impl ProxyState {
534551
};
535552

536553
let cache_key = (actor_id, ip_addr);
537-
if let Some(mut counter) = self
538-
.in_flight_counters
539-
.get_async(&cache_key)
540-
.instrument(tracing::info_span!("get_async"))
541-
.await
542-
{
554+
if let Some(counter_arc) = self.in_flight_counters.get(&cache_key).await {
555+
let mut counter = counter_arc.lock().await;
543556
counter.release();
544557
}
545558
}
@@ -651,24 +664,24 @@ impl ProxyService {
651664
.status(StatusCode::TOO_MANY_REQUESTS)
652665
.body(Full::<Bytes>::new(Bytes::new()))
653666
.map_err(Into::into)
654-
} else {
667+
} else {
655668
// Increment metrics
656669
metrics::PROXY_REQUEST_PENDING
657670
.with_label_values(&[&actor_id_str, &server_id_str, method_str, &path])
658671
.inc();
659-
672+
660673
metrics::PROXY_REQUEST_TOTAL
661674
.with_label_values(&[&actor_id_str, &server_id_str, method_str, &path])
662675
.inc();
663-
676+
664677
// Prepare to release in-flight counter when done
665678
let state_clone = self.state.clone();
666679
crate::defer! {
667680
tokio::spawn(async move {
668681
state_clone.release_in_flight(client_ip, &actor_id).await;
669682
}.instrument(tracing::info_span!("release_in_flight_task")));
670683
}
671-
684+
672685
// Branch for WebSocket vs HTTP handling
673686
// Both paths will handle their own metrics and error handling
674687
if hyper_tungstenite::is_upgrade_request(&req) {
@@ -688,20 +701,11 @@ impl ProxyService {
688701
// Record metrics
689702
let duration = start_time.elapsed();
690703
metrics::PROXY_REQUEST_DURATION
691-
.with_label_values(&[
692-
&actor_id_str,
693-
&server_id_str,
694-
&status,
695-
])
704+
.with_label_values(&[&actor_id_str, &server_id_str, &status])
696705
.observe(duration.as_secs_f64());
697706

698707
metrics::PROXY_REQUEST_PENDING
699-
.with_label_values(&[
700-
&actor_id_str,
701-
&server_id_str,
702-
method_str,
703-
&path,
704-
])
708+
.with_label_values(&[&actor_id_str, &server_id_str, method_str, &path])
705709
.dec();
706710

707711
res
@@ -1614,8 +1618,6 @@ impl ProxyService {
16141618
"Request received"
16151619
);
16161620

1617-
let start_time = Instant::now();
1618-
16191621
// Process the request
16201622
let result = self.handle_request(req).await;
16211623

0 commit comments

Comments
 (0)