Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,4 @@ serde_yaml2 = "0.1.2"

luks2 = "0.5.0"
scopeguard = "1.2.0"
proxy-protocol = "0.5.0"
1 change: 1 addition & 0 deletions gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ reqwest = { workspace = true, features = ["json"] }
hyper = { workspace = true, features = ["server", "http1"] }
hyper-util = { version = "0.1", features = ["tokio"] }
jemallocator.workspace = true
proxy-protocol.workspace = true

[target.'cfg(unix)'.dependencies]
nix = { workspace = true, features = ["resource"] }
Expand Down
3 changes: 3 additions & 0 deletions gateway/gateway.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ app_address_ns_prefix = "_dstack-app-address"
app_address_ns_compat = true
workers = 32
external_port = 443
inbound_pp_enabled = false

[core.proxy.timeouts]
# Timeout for establishing a connection to the target app.
Expand All @@ -88,6 +89,8 @@ write = "5s"
shutdown = "5s"
# Timeout for total connection duration.
total = "5h"
# Timeout for proxy protocol header
pp_header = "5s"

[core.recycle]
enabled = true
Expand Down
3 changes: 3 additions & 0 deletions gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub struct ProxyConfig {
pub workers: usize,
pub app_address_ns_prefix: String,
pub app_address_ns_compat: bool,
pub inbound_pp_enabled: bool,
}

#[derive(Debug, Clone, Deserialize)]
Expand All @@ -106,6 +107,8 @@ pub struct Timeouts {
pub write: Duration,
#[serde(with = "serde_duration")]
pub shutdown: Duration,
#[serde(with = "serde_duration")]
pub pp_header: Duration,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
Expand Down
1 change: 1 addition & 0 deletions gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod admin_service;
mod config;
mod main_service;
mod models;
mod pp;
mod proxy;
mod web_routes;

Expand Down
175 changes: 175 additions & 0 deletions gateway/src/pp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
// SPDX-FileCopyrightText: © 2024-2025 Phala Network <[email protected]>
//
// SPDX-License-Identifier: Apache-2.0

use std::net::SocketAddr;

use anyhow::{bail, Context, Result};
use proxy_protocol::{version1 as v1, version2 as v2, ProxyHeader};
use tokio::{
io::{AsyncRead, AsyncReadExt},
net::TcpStream,
};

use crate::config::ProxyConfig;

const V1_PROTOCOL_PREFIX: &str = "PROXY";
const V1_PREFIX_LEN: usize = 5;
const V1_MAX_LENGTH: usize = 107;
const V1_TERMINATOR: &[u8] = b"\r\n";

const V2_PROTOCOL_PREFIX: &[u8] = b"\r\n\r\n\0\r\nQUIT\n";
const V2_PREFIX_LEN: usize = 12;
const V2_MINIMUM_LEN: usize = 16;
const V2_LENGTH_INDEX: usize = 14;
const READ_BUFFER_LEN: usize = 512;
const V2_MAX_LENGTH: usize = 2048;

pub(crate) async fn get_inbound_pp_header(
inbound: TcpStream,
config: &ProxyConfig,
) -> Result<(TcpStream, ProxyHeader)> {
if config.inbound_pp_enabled {
read_proxy_header(inbound).await
} else {
let header = create_inbound_pp_header(&inbound);
Ok((inbound, header))
}
}

pub struct DisplayAddr<'a>(pub &'a ProxyHeader);

impl std::fmt::Display for DisplayAddr<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0 {
ProxyHeader::Version2 { addresses, .. } => match addresses {
v2::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source),
v2::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source),
v2::ProxyAddresses::Unix { .. } => write!(f, "<unix>"),
v2::ProxyAddresses::Unspec => write!(f, "<unspec>"),
},
ProxyHeader::Version1 { addresses, .. } => match addresses {
v1::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source),
v1::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source),
v1::ProxyAddresses::Unknown => write!(f, "<unknown>"),
},
_ => write!(f, "<unknown ver>"),
}
}
}

fn create_inbound_pp_header(inbound: &TcpStream) -> ProxyHeader {
// When PROXY protocol is disabled, create a synthetic header from the actual TCP connection
let peer_addr = inbound.peer_addr().ok();
let local_addr = inbound.local_addr().ok();

match (peer_addr, local_addr) {
(Some(SocketAddr::V4(source)), Some(SocketAddr::V4(destination))) => {
ProxyHeader::Version2 {
command: v2::ProxyCommand::Proxy,
transport_protocol: v2::ProxyTransportProtocol::Stream,
addresses: v2::ProxyAddresses::Ipv4 {
source,
destination,
},
}
}
(Some(SocketAddr::V6(source)), Some(SocketAddr::V6(destination))) => {
ProxyHeader::Version2 {
command: v2::ProxyCommand::Proxy,
transport_protocol: v2::ProxyTransportProtocol::Stream,
addresses: v2::ProxyAddresses::Ipv6 {
source,
destination,
},
}
}
_ => ProxyHeader::Version2 {
command: v2::ProxyCommand::Proxy,
transport_protocol: v2::ProxyTransportProtocol::Stream,
addresses: v2::ProxyAddresses::Unspec,
},
}
}

async fn read_proxy_header<I>(mut stream: I) -> Result<(I, ProxyHeader)>
where
I: AsyncRead + Unpin,
{
let mut buffer = [0; READ_BUFFER_LEN];
let mut dynamic_buffer = None;

stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?;

if &buffer[..V1_PREFIX_LEN] == V1_PROTOCOL_PREFIX.as_bytes() {
read_v1_header(&mut stream, &mut buffer).await?;
} else {
stream
.read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN])
.await?;
if &buffer[..V2_PREFIX_LEN] == V2_PROTOCOL_PREFIX {
dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?;
} else {
bail!("No valid Proxy Protocol header detected");
}
}

let mut buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]);

let header =
proxy_protocol::parse(&mut buffer).context("failed to parse proxy protocol header")?;
Ok((stream, header))
}

async fn read_v2_header<I>(
mut stream: I,
buffer: &mut [u8; READ_BUFFER_LEN],
) -> Result<Option<Vec<u8>>>
where
I: AsyncRead + Unpin,
{
let length =
u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize;
let full_length = V2_MINIMUM_LEN + length;

if full_length > V2_MAX_LENGTH {
bail!("V2 Proxy Protocol header is too long");
}

if full_length > READ_BUFFER_LEN {
let mut dynamic_buffer = Vec::with_capacity(full_length);
dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]);
dynamic_buffer.resize(full_length, 0);
stream
.read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length])
.await?;

Ok(Some(dynamic_buffer))
} else {
stream
.read_exact(&mut buffer[V2_MINIMUM_LEN..full_length])
.await?;

Ok(None)
}
}

async fn read_v1_header<I>(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<()>
where
I: AsyncRead + Unpin,
{
let mut end_found = false;
for i in V1_PREFIX_LEN..V1_MAX_LENGTH {
buffer[i] = stream.read_u8().await?;

if [buffer[i - 1], buffer[i]] == V1_TERMINATOR {
end_found = true;
break;
}
}
if !end_found {
bail!("No valid Proxy Protocol header detected");
}

Ok(())
}
Loading
Loading