|
| 1 | +// SPDX-FileCopyrightText: © 2024-2025 Phala Network <[email protected]> |
| 2 | +// |
| 3 | +// SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +use std::net::SocketAddr; |
| 6 | + |
| 7 | +use anyhow::{bail, Context, Result}; |
| 8 | +use proxy_protocol::{version1 as v1, version2 as v2, ProxyHeader}; |
| 9 | +use tokio::{ |
| 10 | + io::{AsyncRead, AsyncReadExt}, |
| 11 | + net::TcpStream, |
| 12 | +}; |
| 13 | + |
| 14 | +use crate::config::ProxyConfig; |
| 15 | + |
| 16 | +const V1_PROTOCOL_PREFIX: &str = "PROXY"; |
| 17 | +const V1_PREFIX_LEN: usize = 5; |
| 18 | +const V1_MAX_LENGTH: usize = 107; |
| 19 | +const V1_TERMINATOR: &[u8] = b"\r\n"; |
| 20 | + |
| 21 | +const V2_PROTOCOL_PREFIX: &[u8] = b"\r\n\r\n\0\r\nQUIT\n"; |
| 22 | +const V2_PREFIX_LEN: usize = 12; |
| 23 | +const V2_MINIMUM_LEN: usize = 16; |
| 24 | +const V2_LENGTH_INDEX: usize = 14; |
| 25 | +const READ_BUFFER_LEN: usize = 512; |
| 26 | +const V2_MAX_LENGTH: usize = 2048; |
| 27 | + |
| 28 | +pub(crate) async fn get_inbound_pp_header( |
| 29 | + inbound: TcpStream, |
| 30 | + config: &ProxyConfig, |
| 31 | +) -> Result<(TcpStream, ProxyHeader)> { |
| 32 | + if config.inbound_pp_enabled { |
| 33 | + read_proxy_header(inbound).await |
| 34 | + } else { |
| 35 | + let header = create_inbound_pp_header(&inbound); |
| 36 | + Ok((inbound, header)) |
| 37 | + } |
| 38 | +} |
| 39 | + |
| 40 | +pub struct DisplayAddr<'a>(pub &'a ProxyHeader); |
| 41 | + |
| 42 | +impl std::fmt::Display for DisplayAddr<'_> { |
| 43 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 44 | + match self.0 { |
| 45 | + ProxyHeader::Version2 { addresses, .. } => match addresses { |
| 46 | + v2::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source), |
| 47 | + v2::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source), |
| 48 | + v2::ProxyAddresses::Unix { .. } => write!(f, "<unix>"), |
| 49 | + v2::ProxyAddresses::Unspec => write!(f, "<unspec>"), |
| 50 | + }, |
| 51 | + ProxyHeader::Version1 { addresses, .. } => match addresses { |
| 52 | + v1::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source), |
| 53 | + v1::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source), |
| 54 | + v1::ProxyAddresses::Unknown => write!(f, "<unknown>"), |
| 55 | + }, |
| 56 | + _ => write!(f, "<unknown ver>"), |
| 57 | + } |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +fn create_inbound_pp_header(inbound: &TcpStream) -> ProxyHeader { |
| 62 | + // When PROXY protocol is disabled, create a synthetic header from the actual TCP connection |
| 63 | + let peer_addr = inbound.peer_addr().ok(); |
| 64 | + let local_addr = inbound.local_addr().ok(); |
| 65 | + |
| 66 | + match (peer_addr, local_addr) { |
| 67 | + (Some(SocketAddr::V4(source)), Some(SocketAddr::V4(destination))) => { |
| 68 | + ProxyHeader::Version2 { |
| 69 | + command: v2::ProxyCommand::Proxy, |
| 70 | + transport_protocol: v2::ProxyTransportProtocol::Stream, |
| 71 | + addresses: v2::ProxyAddresses::Ipv4 { |
| 72 | + source, |
| 73 | + destination, |
| 74 | + }, |
| 75 | + } |
| 76 | + } |
| 77 | + (Some(SocketAddr::V6(source)), Some(SocketAddr::V6(destination))) => { |
| 78 | + ProxyHeader::Version2 { |
| 79 | + command: v2::ProxyCommand::Proxy, |
| 80 | + transport_protocol: v2::ProxyTransportProtocol::Stream, |
| 81 | + addresses: v2::ProxyAddresses::Ipv6 { |
| 82 | + source, |
| 83 | + destination, |
| 84 | + }, |
| 85 | + } |
| 86 | + } |
| 87 | + _ => ProxyHeader::Version2 { |
| 88 | + command: v2::ProxyCommand::Proxy, |
| 89 | + transport_protocol: v2::ProxyTransportProtocol::Stream, |
| 90 | + addresses: v2::ProxyAddresses::Unspec, |
| 91 | + }, |
| 92 | + } |
| 93 | +} |
| 94 | + |
| 95 | +async fn read_proxy_header<I>(mut stream: I) -> Result<(I, ProxyHeader)> |
| 96 | +where |
| 97 | + I: AsyncRead + Unpin, |
| 98 | +{ |
| 99 | + let mut buffer = [0; READ_BUFFER_LEN]; |
| 100 | + let mut dynamic_buffer = None; |
| 101 | + |
| 102 | + stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?; |
| 103 | + |
| 104 | + if &buffer[..V1_PREFIX_LEN] == V1_PROTOCOL_PREFIX.as_bytes() { |
| 105 | + read_v1_header(&mut stream, &mut buffer).await?; |
| 106 | + } else { |
| 107 | + stream |
| 108 | + .read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN]) |
| 109 | + .await?; |
| 110 | + if &buffer[..V2_PREFIX_LEN] == V2_PROTOCOL_PREFIX { |
| 111 | + dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?; |
| 112 | + } else { |
| 113 | + bail!("No valid Proxy Protocol header detected"); |
| 114 | + } |
| 115 | + } |
| 116 | + |
| 117 | + let mut buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]); |
| 118 | + |
| 119 | + let header = |
| 120 | + proxy_protocol::parse(&mut buffer).context("failed to parse proxy protocol header")?; |
| 121 | + Ok((stream, header)) |
| 122 | +} |
| 123 | + |
| 124 | +async fn read_v2_header<I>( |
| 125 | + mut stream: I, |
| 126 | + buffer: &mut [u8; READ_BUFFER_LEN], |
| 127 | +) -> Result<Option<Vec<u8>>> |
| 128 | +where |
| 129 | + I: AsyncRead + Unpin, |
| 130 | +{ |
| 131 | + let length = |
| 132 | + u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize; |
| 133 | + let full_length = V2_MINIMUM_LEN + length; |
| 134 | + |
| 135 | + if full_length > V2_MAX_LENGTH { |
| 136 | + bail!("V2 Proxy Protocol header is too long"); |
| 137 | + } |
| 138 | + |
| 139 | + if full_length > READ_BUFFER_LEN { |
| 140 | + let mut dynamic_buffer = Vec::with_capacity(full_length); |
| 141 | + dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]); |
| 142 | + dynamic_buffer.resize(full_length, 0); |
| 143 | + stream |
| 144 | + .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length]) |
| 145 | + .await?; |
| 146 | + |
| 147 | + Ok(Some(dynamic_buffer)) |
| 148 | + } else { |
| 149 | + stream |
| 150 | + .read_exact(&mut buffer[V2_MINIMUM_LEN..full_length]) |
| 151 | + .await?; |
| 152 | + |
| 153 | + Ok(None) |
| 154 | + } |
| 155 | +} |
| 156 | + |
| 157 | +async fn read_v1_header<I>(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<()> |
| 158 | +where |
| 159 | + I: AsyncRead + Unpin, |
| 160 | +{ |
| 161 | + let mut end_found = false; |
| 162 | + for i in V1_PREFIX_LEN..V1_MAX_LENGTH { |
| 163 | + buffer[i] = stream.read_u8().await?; |
| 164 | + |
| 165 | + if [buffer[i - 1], buffer[i]] == V1_TERMINATOR { |
| 166 | + end_found = true; |
| 167 | + break; |
| 168 | + } |
| 169 | + } |
| 170 | + if !end_found { |
| 171 | + bail!("No valid Proxy Protocol header detected"); |
| 172 | + } |
| 173 | + |
| 174 | + Ok(()) |
| 175 | +} |
0 commit comments