From e43a118539a580c7e997dff5bbe1a5c114c5c6da Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 27 Oct 2025 19:54:20 -0400 Subject: [PATCH] Cap control socket payload size --- ntpd/src/daemon/sockets.rs | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/ntpd/src/daemon/sockets.rs b/ntpd/src/daemon/sockets.rs index ecc2357c3..c6eed8d10 100644 --- a/ntpd/src/daemon/sockets.rs +++ b/ntpd/src/daemon/sockets.rs @@ -3,6 +3,8 @@ use std::path::Path; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +const MAX_JSON_MESSAGE_SIZE: u64 = 1 << 20; // 1 MiB + pub async fn write_json(stream: &mut (impl AsyncWrite + Unpin), value: &T) -> std::io::Result<()> where T: serde::Serialize, @@ -20,7 +22,19 @@ where T: serde::Deserialize<'a>, { buffer.clear(); - let msg_size = stream.read_u64().await? as usize; + let msg_size = stream.read_u64().await?; + if msg_size > MAX_JSON_MESSAGE_SIZE { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "message too large", + )); + } + let msg_size: usize = msg_size.try_into().map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "message size cannot be represented", + ) + })?; buffer.resize(msg_size, 0); stream.read_exact(buffer).await?; serde_json::from_slice(buffer) @@ -142,4 +156,25 @@ mod tests { // the logic will automatically grow the buffer to the required size assert!(!buf.is_empty()); } + + #[tokio::test] + async fn oversized_messages_are_rejected() { + // be careful with copying: tests run concurrently and should use a unique socket name! + let path = std::env::temp_dir().join(format!("ntp-test-stream-{}", alloc_port())); + if path.exists() { + std::fs::remove_file(&path).unwrap(); + } + let listener = UnixListener::bind(&path).unwrap(); + let mut writer = UnixStream::connect(&path).await.unwrap(); + + let (mut reader, _) = listener.accept().await.unwrap(); + + let oversized = MAX_JSON_MESSAGE_SIZE + 1; + writer.write_u64(oversized).await.unwrap(); + + let mut buf = Vec::new(); + let err = read_json::>(&mut reader, &mut buf).await.unwrap_err(); + + assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput); + } }