2 changes: 1 addition & 1 deletion packages/edge/infra/guard/core/Cargo.toml
@@ -27,7 +27,7 @@ prometheus = "0.13.3"
rivet-config.workspace = true
rand = "0.8.5"
cluster.workspace = true
scc = "2.0.7"
moka = { version = "0.12", features = ["future"] }
pegboard.workspace = true
regex = "1.10.3"
futures-util = "0.3.30"
142 changes: 72 additions & 70 deletions packages/edge/infra/guard/core/src/proxy_service.rs
@@ -8,6 +8,8 @@ use std::{
time::{Duration, Instant},
};

use tokio::sync::Mutex;

use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
use global_error::*;
@@ -18,8 +20,8 @@ use hyper::{Request, Response, StatusCode};
use hyper_tungstenite;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use moka::future::Cache;
use rand;
use scc::HashMap as SccHashMap;
use serde_json;
use tokio::time::timeout;
use tracing::Instrument;
@@ -29,6 +31,8 @@ use uuid::Uuid;
use crate::metrics;

const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
const ROUTE_CACHE_TTL: Duration = Duration::from_secs(60 * 10); // 10 minutes
const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour

// Routing types
#[derive(Clone, Debug)]
@@ -154,38 +158,40 @@ pub type MiddlewareFn = Arc<

// Cache for routing results
struct RouteCache {
cache: SccHashMap<(String, String), RouteConfig>,
cache: Cache<(String, String), RouteConfig>,
}

impl RouteCache {
fn new() -> Self {
Self {
cache: SccHashMap::new(),
cache: Cache::builder()
.max_capacity(10_000)
.time_to_live(ROUTE_CACHE_TTL)
.build(),
}
}

#[tracing::instrument(skip_all)]
async fn get(&self, hostname: &str, path: &str) -> Option<RouteConfig> {
self.cache
.get_async(&(hostname.to_owned(), path.to_owned()))
.get(&(hostname.to_owned(), path.to_owned()))
.await
.map(|v| v.clone())
}

#[tracing::instrument(skip_all)]
async fn insert(&self, hostname: String, path: String, result: RouteConfig) {
self.cache.upsert_async((hostname, path), result).await;
self.cache.insert((hostname, path), result).await;

metrics::ROUTE_CACHE_SIZE.set(self.cache.len() as i64);
metrics::ROUTE_CACHE_SIZE.set(self.cache.entry_count() as i64);
}

#[tracing::instrument(skip_all)]
async fn purge(&self, hostname: &str, path: &str) {
self.cache
.remove_async(&(hostname.to_owned(), path.to_owned()))
.invalidate(&(hostname.to_owned(), path.to_owned()))
.await;

metrics::ROUTE_CACHE_SIZE.set(self.cache.len() as i64);
metrics::ROUTE_CACHE_SIZE.set(self.cache.entry_count() as i64);
}
}
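One note on the ROUTE_CACHE_SIZE updates above: moka's entry_count() is eventually consistent, so the gauge can briefly lag recent insert/invalidate calls. If an exact reading matters, moka's pending maintenance work can be flushed first. A minimal sketch, assuming moka 0.12's future cache (run_pending_tasks is its maintenance hook); the helper name is illustrative and not part of this PR:

use moka::future::Cache;
use std::hash::Hash;

// Sketch: flush moka's pending maintenance before reading entry_count(),
// since the count is eventually consistent and may lag recent writes.
async fn settled_entry_count<K, V>(cache: &Cache<K, V>) -> u64
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    cache.run_pending_tasks().await;
    cache.entry_count()
}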

@@ -257,8 +263,8 @@ pub struct ProxyState {
routing_fn: RoutingFn,
middleware_fn: MiddlewareFn,
route_cache: RouteCache,
rate_limiters: SccHashMap<(Uuid, std::net::IpAddr), RateLimiter>,
in_flight_counters: SccHashMap<(Uuid, std::net::IpAddr), InFlightCounter>,
rate_limiters: Cache<(Uuid, std::net::IpAddr), Arc<Mutex<RateLimiter>>>,
in_flight_counters: Cache<(Uuid, std::net::IpAddr), Arc<Mutex<InFlightCounter>>>,
port_type: PortType,
}

@@ -274,8 +280,14 @@ impl ProxyState {
routing_fn,
middleware_fn,
route_cache: RouteCache::new(),
rate_limiters: SccHashMap::new(),
in_flight_counters: SccHashMap::new(),
rate_limiters: Cache::builder()
.max_capacity(10_000)
.time_to_live(PROXY_STATE_CACHE_TTL)
.build(),
in_flight_counters: Cache::builder()
.max_capacity(10_000)
.time_to_live(PROXY_STATE_CACHE_TTL)
.build(),
port_type,
}
}
@@ -465,28 +477,29 @@ impl ProxyState {
let middleware_config = self.get_middleware_config(&actor_id).await?;

let cache_key = (actor_id, ip_addr);
let entry = self
.rate_limiters
.entry_async(cache_key)
.instrument(tracing::info_span!("entry_async"))
.await;
if let scc::hash_map::Entry::Occupied(mut entry) = entry {
// Key exists, get and mutate existing RateLimiter
let write_guard = entry.get_mut();
Ok(write_guard.try_acquire())

// Get existing limiter or create a new one
let limiter_arc = if let Some(existing_limiter) = self.rate_limiters.get(&cache_key).await {
existing_limiter
} else {
// Key doesn't exist, insert a new RateLimiter
let mut limiter = RateLimiter::new(
let new_limiter = Arc::new(Mutex::new(RateLimiter::new(
middleware_config.rate_limit.requests,
middleware_config.rate_limit.period,
);
let result = limiter.try_acquire();
entry.insert_entry(limiter);
)));
self.rate_limiters
.insert(cache_key, new_limiter.clone())
.await;
metrics::RATE_LIMITER_COUNT.set(self.rate_limiters.entry_count() as i64);
new_limiter
};

metrics::RATE_LIMITER_COUNT.set(self.rate_limiters.len() as i64);
// Try to acquire from the limiter
let result = {
let mut limiter = limiter_arc.lock().await;
limiter.try_acquire()
};

Ok(result)
}
Ok(result)
Comment on lines +496 to +502

logic: Potential deadlock risk if the limiter lock is held during await points. Consider using a timeout on the lock acquisition.

Suggested change
-// Try to acquire from the limiter
-let result = {
-let mut limiter = limiter_arc.lock().await;
-limiter.try_acquire()
-};
-Ok(result)
-}
-Ok(result)
+// Try to acquire from the limiter with timeout
+let result = match tokio::time::timeout(
+Duration::from_secs(5),
+limiter_arc.lock()
+).await {
+Ok(mut limiter) => limiter.try_acquire(),
+Err(_) => false, // Timed out waiting for the lock (tokio's Mutex cannot be poisoned)
+};
+Ok(result)

}
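Separately from the lock-timeout question raised in the comment above, the get-then-insert sequence in the rate-limit path is not atomic: two requests that miss the cache for the same key at the same moment can each build a RateLimiter, and the later insert silently replaces the earlier one. moka can do the get-or-create in a single step. A minimal sketch, assuming moka 0.12's get_with (which runs the init future at most once per key); the free-standing function signature and the stubbed RateLimiter are illustrative stand-ins, not this PR's types:

use std::{net::IpAddr, sync::Arc};

use moka::future::Cache;
use tokio::sync::Mutex;
use uuid::Uuid;

// Stand-in for the PR's RateLimiter; only the call pattern matters here.
struct RateLimiter;

impl RateLimiter {
    fn new(_requests: u64, _period_secs: u64) -> Self {
        Self
    }

    fn try_acquire(&mut self) -> bool {
        true
    }
}

// Sketch: atomic get-or-create, followed by a short, await-free critical section.
async fn check_rate_limit(
    limiters: &Cache<(Uuid, IpAddr), Arc<Mutex<RateLimiter>>>,
    cache_key: (Uuid, IpAddr),
    requests: u64,
    period_secs: u64,
) -> bool {
    // get_with initializes the entry at most once per key, so concurrent
    // misses share a single limiter instead of racing separate insert() calls.
    let limiter_arc = limiters
        .get_with(cache_key, async move {
            Arc::new(Mutex::new(RateLimiter::new(requests, period_secs)))
        })
        .await;

    // The guard is dropped at the end of this expression, before any further
    // await points, which keeps the lock hold time minimal.
    limiter_arc.lock().await.try_acquire()
}

Whether the extra atomicity is worth it here is a judgment call; the worst case with the PR's current approach is one discarded limiter and a reset token count on a cold key.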

#[tracing::instrument(skip_all)]
@@ -504,25 +517,29 @@ impl ProxyState {
let middleware_config = self.get_middleware_config(&actor_id).await?;

let cache_key = (actor_id, ip_addr);
let entry = self
.in_flight_counters
.entry_async(cache_key)
.instrument(tracing::info_span!("entry_async"))
.await;
if let scc::hash_map::Entry::Occupied(mut entry) = entry {
// Key exists, get and mutate existing InFlightCounter
let write_guard = entry.get_mut();
Ok(write_guard.try_acquire())
} else {
// Key doesn't exist, insert a new InFlightCounter
let mut counter = InFlightCounter::new(middleware_config.max_in_flight.amount);
let result = counter.try_acquire();
entry.insert_entry(counter);

metrics::IN_FLIGHT_COUNTER_COUNT.set(self.in_flight_counters.len() as i64);
// Get existing counter or create a new one
let counter_arc =
if let Some(existing_counter) = self.in_flight_counters.get(&cache_key).await {
existing_counter
} else {
let new_counter = Arc::new(Mutex::new(InFlightCounter::new(
middleware_config.max_in_flight.amount,
)));
self.in_flight_counters
.insert(cache_key, new_counter.clone())
.await;
metrics::IN_FLIGHT_COUNTER_COUNT.set(self.in_flight_counters.entry_count() as i64);
new_counter
};

// Try to acquire from the counter
let result = {
let mut counter = counter_arc.lock().await;
counter.try_acquire()
};

Ok(result)
}
Ok(result)
}

#[tracing::instrument(skip_all)]
@@ -534,12 +551,8 @@ impl ProxyState {
};

let cache_key = (actor_id, ip_addr);
if let Some(mut counter) = self
.in_flight_counters
.get_async(&cache_key)
.instrument(tracing::info_span!("get_async"))
.await
{
if let Some(counter_arc) = self.in_flight_counters.get(&cache_key).await {
let mut counter = counter_arc.lock().await;
counter.release();
}
}
@@ -651,24 +664,24 @@ impl ProxyService {
.status(StatusCode::TOO_MANY_REQUESTS)
.body(Full::<Bytes>::new(Bytes::new()))
.map_err(Into::into)
} else {
} else {
// Increment metrics
metrics::PROXY_REQUEST_PENDING
.with_label_values(&[&actor_id_str, &server_id_str, method_str, &path])
.inc();

metrics::PROXY_REQUEST_TOTAL
.with_label_values(&[&actor_id_str, &server_id_str, method_str, &path])
.inc();

// Prepare to release in-flight counter when done
let state_clone = self.state.clone();
crate::defer! {
tokio::spawn(async move {
state_clone.release_in_flight(client_ip, &actor_id).await;
}.instrument(tracing::info_span!("release_in_flight_task")));
}

// Branch for WebSocket vs HTTP handling
// Both paths will handle their own metrics and error handling
if hyper_tungstenite::is_upgrade_request(&req) {
@@ -688,20 +701,11 @@ impl ProxyService {
// Record metrics
let duration = start_time.elapsed();
metrics::PROXY_REQUEST_DURATION
.with_label_values(&[
&actor_id_str,
&server_id_str,
&status,
])
.with_label_values(&[&actor_id_str, &server_id_str, &status])
.observe(duration.as_secs_f64());

metrics::PROXY_REQUEST_PENDING
.with_label_values(&[
&actor_id_str,
&server_id_str,
method_str,
&path,
])
.with_label_values(&[&actor_id_str, &server_id_str, method_str, &path])
.dec();

res
@@ -1614,8 +1618,6 @@ impl ProxyService {
"Request received"
);

let start_time = Instant::now();

// Process the request
let result = self.handle_request(req).await;
