Skip to content

Commit 9d7986a

Browse files
committed
chore: refactors crl watcher
Signed-off-by: nilekh <[email protected]>
1 parent 01540f7 commit 9d7986a

File tree

8 files changed

+198
-258
lines changed

8 files changed

+198
-258
lines changed

src/proxy.rs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ pub struct Proxy {
173173
outbound: Outbound,
174174
socks5: Option<Socks5>,
175175
policy_watcher: PolicyWatcher,
176+
crl_watcher: Option<crate::tls::crl_watcher::CrlWatcher>,
176177
}
177178

178179
pub struct LocalWorkloadInformation {
@@ -302,10 +303,6 @@ impl Proxy {
302303
// We setup all the listeners first so we can capture any errors that should block startup
303304
let inbound = Inbound::new(pi.clone(), drain.clone()).await?;
304305

305-
if let Some(ref crl_mgr) = pi.crl_manager {
306-
crl_mgr.register_connection_manager(pi.connection_manager.clone());
307-
}
308-
309306
// This exists for `direct` integ tests, no other reason
310307
#[cfg(any(test, feature = "testing"))]
311308
if pi.cfg.fake_self_inbound {
@@ -326,15 +323,29 @@ impl Proxy {
326323
} else {
327324
None
328325
};
329-
let policy_watcher =
330-
PolicyWatcher::new(pi.state.clone(), drain, pi.connection_manager.clone());
326+
let policy_watcher = PolicyWatcher::new(
327+
pi.state.clone(),
328+
drain.clone(),
329+
pi.connection_manager.clone(),
330+
);
331+
332+
let crl_watcher = if let Some(ref crl_mgr) = pi.crl_manager {
333+
Some(crate::tls::crl_watcher::CrlWatcher::new(
334+
crl_mgr.clone(),
335+
drain,
336+
pi.connection_manager.clone(),
337+
))
338+
} else {
339+
None
340+
};
331341

332342
Ok(Proxy {
333343
inbound,
334344
inbound_passthrough,
335345
outbound,
336346
socks5,
337347
policy_watcher,
348+
crl_watcher,
338349
})
339350
}
340351

@@ -350,6 +361,10 @@ impl Proxy {
350361
tasks.push(tokio::spawn(socks5.run().in_current_span()));
351362
};
352363

364+
if let Some(crl_watcher) = self.crl_watcher {
365+
tasks.push(tokio::spawn(crl_watcher.run().in_current_span()));
366+
};
367+
353368
futures::future::join_all(tasks).await;
354369
}
355370

src/proxy/connection_manager.rs

Lines changed: 1 addition & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ impl ConnectionDrain {
5959
pub struct ConnectionManager {
6060
drains: Arc<RwLock<HashMap<InboundConnection, ConnectionDrain>>>,
6161
outbound_connections: Arc<RwLock<HashSet<OutboundConnection>>>,
62-
close_tx: Option<tokio::sync::mpsc::UnboundedSender<InboundConnection>>,
6362
}
6463

6564
impl std::fmt::Debug for ConnectionManager {
@@ -73,7 +72,6 @@ impl Default for ConnectionManager {
7372
ConnectionManager {
7473
drains: Arc::new(RwLock::new(HashMap::new())),
7574
outbound_connections: Arc::new(RwLock::new(HashSet::new())),
76-
close_tx: None,
7775
}
7876
}
7977
}
@@ -159,38 +157,6 @@ pub struct InboundConnection {
159157
}
160158

161159
impl ConnectionManager {
162-
// Initialize the connection manager with CRL support
163-
// This spawns an async task that handles connection closes from the CRL watcher
164-
pub fn new_with_crl_support() -> Self {
165-
let (close_tx, mut close_rx) = tokio::sync::mpsc::unbounded_channel::<InboundConnection>();
166-
167-
let cm = ConnectionManager {
168-
drains: Arc::new(RwLock::new(HashMap::new())),
169-
outbound_connections: Arc::new(RwLock::new(HashSet::new())),
170-
close_tx: Some(close_tx),
171-
};
172-
173-
// Spawn async task to handle close requests
174-
let cm_clone = ConnectionManager {
175-
drains: cm.drains.clone(),
176-
outbound_connections: cm.outbound_connections.clone(),
177-
close_tx: None,
178-
};
179-
180-
tokio::spawn(async move {
181-
while let Some(conn) = close_rx.recv().await {
182-
tracing::debug!(
183-
"Processing close request for connection from {}",
184-
conn.ctx.conn.src
185-
);
186-
cm_clone.close(&conn).await;
187-
}
188-
tracing::debug!("Connection close handler terminated");
189-
});
190-
191-
cm
192-
}
193-
194160
pub fn track_outbound(
195161
&self,
196162
src: SocketAddr,
@@ -306,7 +272,7 @@ impl ConnectionManager {
306272
}
307273

308274
// signal all connections listening to this channel to take action (typically terminate traffic)
309-
async fn close(&self, c: &InboundConnection) {
275+
pub async fn close(&self, c: &InboundConnection) {
310276
let drain = { self.drains.write().expect("mutex").remove(c) };
311277
if let Some(cd) = drain {
312278
cd.drain().await;
@@ -321,120 +287,6 @@ impl ConnectionManager {
321287
// potentially large copy under read lock, could require optimization
322288
self.drains.read().expect("mutex").keys().cloned().collect()
323289
}
324-
325-
/// Close only inbound connections whose client certificates have been revoked
326-
/// This is called when CRL updates with new revocations
327-
pub fn close_revoked_connections(&self, revoked_serials: &std::collections::HashSet<Vec<u8>>) {
328-
let conns = self.connections();
329-
330-
tracing::debug!(
331-
"checking {} active inbound connections against {} revoked serial(s)",
332-
conns.len(),
333-
revoked_serials.len()
334-
);
335-
336-
// Log all revoked serials we're checking against
337-
for (idx, serial) in revoked_serials.iter().enumerate() {
338-
tracing::debug!(
339-
"revoked serial {}: {} bytes (hex: {})",
340-
idx + 1,
341-
serial.len(),
342-
serial
343-
.iter()
344-
.map(|b| format!("{:02x}", b))
345-
.collect::<String>()
346-
);
347-
}
348-
349-
let mut closed_count = 0;
350-
let mut no_serial_count = 0;
351-
352-
for (idx, conn) in conns.iter().enumerate() {
353-
tracing::debug!(
354-
"connection {}: src={}, has_serials={}",
355-
idx + 1,
356-
conn.ctx.conn.src,
357-
conn.client_cert_serials.is_some()
358-
);
359-
360-
if let Some(ref serials) = conn.client_cert_serials {
361-
tracing::debug!(
362-
" connection {} has {} certificate(s) in chain",
363-
idx + 1,
364-
serials.len()
365-
);
366-
367-
// Check if ANY certificate in the chain is revoked
368-
let mut is_revoked = false;
369-
for (cert_idx, serial) in serials.iter().enumerate() {
370-
let serial_hex = serial
371-
.iter()
372-
.map(|b| format!("{:02x}", b))
373-
.collect::<String>();
374-
tracing::debug!(
375-
" cert {} serial: {} bytes (hex: {})",
376-
cert_idx,
377-
serial.len(),
378-
serial_hex
379-
);
380-
381-
if revoked_serials.contains(serial) {
382-
is_revoked = true;
383-
tracing::warn!(" cert {} serial MATCHES revoked serial!", cert_idx);
384-
break;
385-
}
386-
}
387-
388-
tracing::debug!(
389-
" connection {} has revoked certificate: {}",
390-
idx + 1,
391-
is_revoked
392-
);
393-
394-
if is_revoked {
395-
tracing::warn!(
396-
"closing inbound connection {} from {} due to revoked certificate in chain",
397-
idx + 1,
398-
conn.ctx.conn.src
399-
);
400-
401-
// Send close request through channel (works from blocking context)
402-
if let Some(ref tx) = self.close_tx {
403-
if let Err(e) = tx.send(conn.clone()) {
404-
tracing::error!("failed to send close request: {}", e);
405-
} else {
406-
closed_count += 1;
407-
}
408-
} else {
409-
tracing::warn!("CRL support not initialized - cannot close connection");
410-
}
411-
}
412-
} else {
413-
no_serial_count += 1;
414-
tracing::debug!(
415-
" connection {} has no client certificate serials (skipping)",
416-
idx + 1
417-
);
418-
}
419-
}
420-
421-
tracing::info!(
422-
"connection closure summary: {} total connections checked, {} with serials, {} without serials, {} closed",
423-
conns.len(),
424-
conns.len() - no_serial_count,
425-
no_serial_count,
426-
closed_count
427-
);
428-
429-
if closed_count > 0 {
430-
tracing::info!(
431-
"closed {} inbound connection(s) with revoked certificates",
432-
closed_count
433-
);
434-
} else {
435-
tracing::debug!("no active connections with newly revoked certificates");
436-
}
437-
}
438290
}
439291

440292
#[derive(serde::Serialize)]

src/proxy/h2/server.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use std::sync::Arc;
2525
use std::sync::atomic::{AtomicBool, Ordering};
2626
use tokio::net::TcpStream;
2727
use tokio::sync::{oneshot, watch};
28-
use tracing::{Instrument, debug};
28+
use tracing::{Instrument, debug, info};
2929

3030
pub struct H2Request {
3131
request: Parts,
@@ -99,6 +99,7 @@ pub async fn serve_connection<F, Fut>(
9999
s: tokio_rustls::server::TlsStream<TcpStream>,
100100
drain: DrainWatcher,
101101
mut force_shutdown: watch::Receiver<()>,
102+
tls_drain: Option<DrainWatcher>,
102103
handler: F,
103104
) -> Result<(), Error>
104105
where
@@ -169,6 +170,16 @@ where
169170
conn.graceful_shutdown();
170171
break;
171172
}
173+
_tls_shutdown = async {
174+
match &tls_drain {
175+
Some(d) => d.clone().wait_for_drain().await,
176+
None => std::future::pending().await,
177+
}
178+
} => {
179+
info!("starting graceful drain (TLS connection certificate revoked)");
180+
conn.graceful_shutdown();
181+
break;
182+
}
172183
}
173184
}
174185
// Signal to the ping_pong it should also stop.

0 commit comments

Comments
 (0)