Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion ntpd/src/daemon/sockets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::path::Path;

use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

const MAX_JSON_MESSAGE_SIZE: u64 = 1 << 20; // 1 MiB

pub async fn write_json<T>(stream: &mut (impl AsyncWrite + Unpin), value: &T) -> std::io::Result<()>
where
T: serde::Serialize,
Expand All @@ -20,7 +22,19 @@ where
T: serde::Deserialize<'a>,
{
buffer.clear();
let msg_size = stream.read_u64().await? as usize;
let msg_size = stream.read_u64().await?;
if msg_size > MAX_JSON_MESSAGE_SIZE {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"message too large",
));
}
let msg_size: usize = msg_size.try_into().map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"message size cannot be represented",
)
})?;
buffer.resize(msg_size, 0);
stream.read_exact(buffer).await?;
serde_json::from_slice(buffer)
Expand Down Expand Up @@ -142,4 +156,25 @@ mod tests {
// the logic will automatically grow the buffer to the required size
assert!(!buf.is_empty());
}

#[tokio::test]
async fn oversized_messages_are_rejected() {
// be careful with copying: tests run concurrently and should use a unique socket name!
let path = std::env::temp_dir().join(format!("ntp-test-stream-{}", alloc_port()));
if path.exists() {
std::fs::remove_file(&path).unwrap();
}
let listener = UnixListener::bind(&path).unwrap();
let mut writer = UnixStream::connect(&path).await.unwrap();

let (mut reader, _) = listener.accept().await.unwrap();

let oversized = MAX_JSON_MESSAGE_SIZE + 1;
writer.write_u64(oversized).await.unwrap();

let mut buf = Vec::new();
let err = read_json::<Vec<usize>>(&mut reader, &mut buf).await.unwrap_err();

assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
}
}
Loading