diff --git a/core/src/core/pmmr/segment.rs b/core/src/core/pmmr/segment.rs index 8df8b8f989..cf5970b94d 100644 --- a/core/src/core/pmmr/segment.rs +++ b/core/src/core/pmmr/segment.rs @@ -15,14 +15,25 @@ //! Segment of a PMMR. use crate::core::hash::Hash; -use crate::core::pmmr::{self, Backend, ReadablePMMR, ReadonlyPMMR}; -use crate::ser::{Error, PMMRIndexHashable, PMMRable, Readable, Reader, Writeable, Writer}; +use crate::core::pmmr; +use crate::core::pmmr::{Backend, ReadablePMMR, ReadonlyPMMR}; +use crate::ser::{ + Error as SerError, PMMRIndexHashable, PMMRable, Readable, Reader, Writeable, Writer, +}; use croaring::Bitmap; +use log::error; use std::cmp::min; use std::fmt::Debug; -#[derive(Clone, Debug, Eq, PartialEq)] +// Temporary limits based on MAX_SEGMENT_HEIGHT = 11 in chain/src/pibd_params.rs +// Max leaves in a segment of height 11 = 2^11 = 2048 +// Max hashes in a segment of height 11 = 2^11 - 1 = 2047 +// We use 2048 for both as a safe upper bound. +const MAX_HASHES_PER_SEGMENT: usize = 2048; +const MAX_LEAVES_PER_SEGMENT: usize = 2048; + /// Possible segment types, according to this desegmenter +#[derive(Clone, Debug, Eq, PartialEq)] pub enum SegmentType { /// Output Bitmap Bitmap, @@ -81,7 +92,7 @@ pub struct SegmentIdentifier { } impl Readable for SegmentIdentifier { - fn read(reader: &mut R) -> Result { + fn read(reader: &mut R) -> Result { let height = reader.read_u8()?; let idx = reader.read_u64()?; Ok(Self { height, idx }) @@ -89,7 +100,7 @@ impl Readable for SegmentIdentifier { } impl Writeable for SegmentIdentifier { - fn write(&self, writer: &mut W) -> Result<(), Error> { + fn write(&self, writer: &mut W) -> Result<(), SerError> { writer.write_u8(self.height)?; writer.write_u64(self.idx) } @@ -458,7 +469,7 @@ where // Not full (only final segment): peaks in segment, bag them together let peaks = pmmr::peaks(mmr_size) .into_iter() - .filter(|&pos0| pos0 >= segment_first_pos && pos0 <= segment_last_pos) + .filter(|&x| x >= segment_first_pos && x <= segment_last_pos) .rev(); let mut hash = None; for pos0 in peaks { @@ -565,16 +576,27 @@ where } impl Readable for Segment { - fn read(reader: &mut R) -> Result { - let identifier = Readable::read(reader)?; + fn read(reader: &mut R) -> Result { + let identifier: SegmentIdentifier = Readable::read(reader)?; let n_hashes = reader.read_u64()? as usize; + + // Check against the maximum allowed size before allocating + if n_hashes > MAX_HASHES_PER_SEGMENT { + let err_msg = format!( + "Segment {:?} hash count {} exceeds limit {}", // Use {:?} + identifier, n_hashes, MAX_HASHES_PER_SEGMENT + ); + error!("PMMR Segment read error: {}", err_msg); + return Err(SerError::PMMRSegmentTooLarge(err_msg)); + } + let mut hash_pos = Vec::with_capacity(n_hashes); let mut last_pos = 0; for _ in 0..n_hashes { let pos = reader.read_u64()?; if pos <= last_pos { - return Err(Error::SortError); + return Err(SerError::SortError); } last_pos = pos; hash_pos.push(pos - 1); @@ -582,16 +604,26 @@ impl Readable for Segment { let mut hashes = Vec::::with_capacity(n_hashes); for _ in 0..n_hashes { - hashes.push(Readable::read(reader)?); + let hash: Hash = Readable::read(reader)?; + hashes.push(hash); } let n_leaves = reader.read_u64()? as usize; + // Also check leaves count for safety + if n_leaves > MAX_LEAVES_PER_SEGMENT { + let err_msg = format!( + "Segment {:?} leaf count {} exceeds limit {}", // Use {:?} + identifier, n_leaves, MAX_LEAVES_PER_SEGMENT + ); + error!("PMMR Segment read error: {}", err_msg); + return Err(SerError::PMMRSegmentTooLarge(err_msg)); + } let mut leaf_pos = Vec::with_capacity(n_leaves); last_pos = 0; for _ in 0..n_leaves { let pos = reader.read_u64()?; if pos <= last_pos { - return Err(Error::SortError); + return Err(SerError::SortError); } last_pos = pos; leaf_pos.push(pos - 1); @@ -616,7 +648,7 @@ impl Readable for Segment { } impl Writeable for Segment { - fn write(&self, writer: &mut W) -> Result<(), Error> { + fn write(&self, writer: &mut W) -> Result<(), SerError> { Writeable::write(&self.identifier, writer)?; writer.write_u64(self.hashes.len() as u64)?; for &pos in &self.hash_pos { @@ -822,7 +854,7 @@ impl SegmentProof { } impl Readable for SegmentProof { - fn read(reader: &mut R) -> Result { + fn read(reader: &mut R) -> Result { let n_hashes = reader.read_u64()? as usize; let mut hashes = Vec::with_capacity(n_hashes); for _ in 0..n_hashes { @@ -834,7 +866,7 @@ impl Readable for SegmentProof { } impl Writeable for SegmentProof { - fn write(&self, writer: &mut W) -> Result<(), Error> { + fn write(&self, writer: &mut W) -> Result<(), SerError> { writer.write_u64(self.hashes.len() as u64)?; for hash in &self.hashes { Writeable::write(hash, writer)?; diff --git a/core/src/ser.rs b/core/src/ser.rs index 38cd9aa1b7..b8635ef36d 100644 --- a/core/src/ser.rs +++ b/core/src/ser.rs @@ -72,6 +72,8 @@ pub enum Error { InvalidBlockVersion, /// Unsupported protocol version UnsupportedProtocolVersion, + /// PMMR segment size exceeds limits during deserialization + PMMRSegmentTooLarge(String), } impl From for Error { @@ -102,6 +104,7 @@ impl fmt::Display for Error { Error::HexError(ref e) => write!(f, "hex error {:?}", e), Error::InvalidBlockVersion => f.write_str("invalid block version"), Error::UnsupportedProtocolVersion => f.write_str("unsupported protocol version"), + Error::PMMRSegmentTooLarge(ref e) => write!(f, "PMMR segment too large: {}", e), } } } @@ -126,6 +129,7 @@ impl error::Error for Error { Error::HexError(_) => "hex error", Error::InvalidBlockVersion => "invalid block version", Error::UnsupportedProtocolVersion => "unsupported protocol version", + Error::PMMRSegmentTooLarge(_) => "PMMR segment too large", } } }