Skip to content

Commit 95fd408

Browse files
committed
gw: Implement proxy protocol
1 parent 1c58db6 commit 95fd408

File tree

10 files changed

+269
-14
lines changed

10 files changed

+269
-14
lines changed

Cargo.lock

Lines changed: 38 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,4 @@ serde_yaml2 = "0.1.2"
218218

219219
luks2 = "0.5.0"
220220
scopeguard = "1.2.0"
221+
proxy-protocol = "0.5.0"

gateway/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ reqwest = { workspace = true, features = ["json"] }
5050
hyper = { workspace = true, features = ["server", "http1"] }
5151
hyper-util = { version = "0.1", features = ["tokio"] }
5252
jemallocator.workspace = true
53+
proxy-protocol.workspace = true
5354

5455
[target.'cfg(unix)'.dependencies]
5556
nix = { workspace = true, features = ["resource"] }

gateway/gateway.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ app_address_ns_prefix = "_dstack-app-address"
6767
app_address_ns_compat = true
6868
workers = 32
6969
external_port = 443
70+
inbound_pp_enabled = false
7071

7172
[core.proxy.timeouts]
7273
# Timeout for establishing a connection to the target app.
@@ -88,6 +89,8 @@ write = "5s"
8889
shutdown = "5s"
8990
# Timeout for total connection duration.
9091
total = "5h"
92+
# Timeout for proxy protocol header
93+
pp_header = "5s"
9194

9295
[core.recycle]
9396
enabled = true

gateway/src/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ pub struct ProxyConfig {
8585
pub workers: usize,
8686
pub app_address_ns_prefix: String,
8787
pub app_address_ns_compat: bool,
88+
pub inbound_pp_enabled: bool,
8889
}
8990

9091
#[derive(Debug, Clone, Deserialize)]
@@ -106,6 +107,8 @@ pub struct Timeouts {
106107
pub write: Duration,
107108
#[serde(with = "serde_duration")]
108109
pub shutdown: Duration,
110+
#[serde(with = "serde_duration")]
111+
pub pp_header: Duration,
109112
}
110113

111114
#[derive(Debug, Clone, Deserialize, Serialize)]

gateway/src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod admin_service;
2121
mod config;
2222
mod main_service;
2323
mod models;
24+
mod pp;
2425
mod proxy;
2526
mod web_routes;
2627

gateway/src/pp.rs

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
// SPDX-FileCopyrightText: © 2024-2025 Phala Network <[email protected]>
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
use std::net::SocketAddr;
6+
7+
use anyhow::{bail, Context, Result};
8+
use proxy_protocol::{version1 as v1, version2 as v2, ProxyHeader};
9+
use tokio::{
10+
io::{AsyncRead, AsyncReadExt},
11+
net::TcpStream,
12+
};
13+
14+
use crate::config::ProxyConfig;
15+
16+
const V1_PROTOCOL_PREFIX: &str = "PROXY";
17+
const V1_PREFIX_LEN: usize = 5;
18+
const V1_MAX_LENGTH: usize = 107;
19+
const V1_TERMINATOR: &[u8] = b"\r\n";
20+
21+
const V2_PROTOCOL_PREFIX: &[u8] = b"\r\n\r\n\0\r\nQUIT\n";
22+
const V2_PREFIX_LEN: usize = 12;
23+
const V2_MINIMUM_LEN: usize = 16;
24+
const V2_LENGTH_INDEX: usize = 14;
25+
const READ_BUFFER_LEN: usize = 512;
26+
const V2_MAX_LENGTH: usize = 2048;
27+
28+
pub(crate) async fn get_inbound_pp_header(
29+
inbound: TcpStream,
30+
config: &ProxyConfig,
31+
) -> Result<(TcpStream, ProxyHeader)> {
32+
if config.inbound_pp_enabled {
33+
read_proxy_header(inbound).await
34+
} else {
35+
let header = create_inbound_pp_header(&inbound);
36+
Ok((inbound, header))
37+
}
38+
}
39+
40+
pub struct DisplayAddr<'a>(pub &'a ProxyHeader);
41+
42+
impl std::fmt::Display for DisplayAddr<'_> {
43+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44+
match self.0 {
45+
ProxyHeader::Version2 { addresses, .. } => match addresses {
46+
v2::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source),
47+
v2::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source),
48+
v2::ProxyAddresses::Unix { .. } => write!(f, "<unix>"),
49+
v2::ProxyAddresses::Unspec => write!(f, "<unspec>"),
50+
},
51+
ProxyHeader::Version1 { addresses, .. } => match addresses {
52+
v1::ProxyAddresses::Ipv4 { source, .. } => write!(f, "{}", source),
53+
v1::ProxyAddresses::Ipv6 { source, .. } => write!(f, "{}", source),
54+
v1::ProxyAddresses::Unknown => write!(f, "<unknown>"),
55+
},
56+
_ => write!(f, "<unknown ver>"),
57+
}
58+
}
59+
}
60+
61+
fn create_inbound_pp_header(inbound: &TcpStream) -> ProxyHeader {
62+
// When PROXY protocol is disabled, create a synthetic header from the actual TCP connection
63+
let peer_addr = inbound.peer_addr().ok();
64+
let local_addr = inbound.local_addr().ok();
65+
66+
match (peer_addr, local_addr) {
67+
(Some(SocketAddr::V4(source)), Some(SocketAddr::V4(destination))) => {
68+
ProxyHeader::Version2 {
69+
command: v2::ProxyCommand::Proxy,
70+
transport_protocol: v2::ProxyTransportProtocol::Stream,
71+
addresses: v2::ProxyAddresses::Ipv4 {
72+
source,
73+
destination,
74+
},
75+
}
76+
}
77+
(Some(SocketAddr::V6(source)), Some(SocketAddr::V6(destination))) => {
78+
ProxyHeader::Version2 {
79+
command: v2::ProxyCommand::Proxy,
80+
transport_protocol: v2::ProxyTransportProtocol::Stream,
81+
addresses: v2::ProxyAddresses::Ipv6 {
82+
source,
83+
destination,
84+
},
85+
}
86+
}
87+
_ => ProxyHeader::Version2 {
88+
command: v2::ProxyCommand::Proxy,
89+
transport_protocol: v2::ProxyTransportProtocol::Stream,
90+
addresses: v2::ProxyAddresses::Unspec,
91+
},
92+
}
93+
}
94+
95+
async fn read_proxy_header<I>(mut stream: I) -> Result<(I, ProxyHeader)>
96+
where
97+
I: AsyncRead + Unpin,
98+
{
99+
let mut buffer = [0; READ_BUFFER_LEN];
100+
let mut dynamic_buffer = None;
101+
102+
stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?;
103+
104+
if &buffer[..V1_PREFIX_LEN] == V1_PROTOCOL_PREFIX.as_bytes() {
105+
read_v1_header(&mut stream, &mut buffer).await?;
106+
} else {
107+
stream
108+
.read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN])
109+
.await?;
110+
if &buffer[..V2_PREFIX_LEN] == V2_PROTOCOL_PREFIX {
111+
dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?;
112+
} else {
113+
bail!("No valid Proxy Protocol header detected");
114+
}
115+
}
116+
117+
let mut buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]);
118+
119+
let header =
120+
proxy_protocol::parse(&mut buffer).context("failed to parse proxy protocol header")?;
121+
Ok((stream, header))
122+
}
123+
124+
async fn read_v2_header<I>(
125+
mut stream: I,
126+
buffer: &mut [u8; READ_BUFFER_LEN],
127+
) -> Result<Option<Vec<u8>>>
128+
where
129+
I: AsyncRead + Unpin,
130+
{
131+
let length =
132+
u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize;
133+
let full_length = V2_MINIMUM_LEN + length;
134+
135+
if full_length > V2_MAX_LENGTH {
136+
bail!("V2 Proxy Protocol header is too long");
137+
}
138+
139+
if full_length > READ_BUFFER_LEN {
140+
let mut dynamic_buffer = Vec::with_capacity(full_length);
141+
dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]);
142+
dynamic_buffer.resize(full_length, 0);
143+
stream
144+
.read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length])
145+
.await?;
146+
147+
Ok(Some(dynamic_buffer))
148+
} else {
149+
stream
150+
.read_exact(&mut buffer[V2_MINIMUM_LEN..full_length])
151+
.await?;
152+
153+
Ok(None)
154+
}
155+
}
156+
157+
async fn read_v1_header<I>(mut stream: I, buffer: &mut [u8; READ_BUFFER_LEN]) -> Result<()>
158+
where
159+
I: AsyncRead + Unpin,
160+
{
161+
let mut end_found = false;
162+
for i in V1_PREFIX_LEN..V1_MAX_LENGTH {
163+
buffer[i] = stream.read_u8().await?;
164+
165+
if [buffer[i - 1], buffer[i]] == V1_TERMINATOR {
166+
end_found = true;
167+
break;
168+
}
169+
}
170+
if !end_found {
171+
bail!("No valid Proxy Protocol header detected");
172+
}
173+
174+
Ok(())
175+
}

gateway/src/proxy.rs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ use tokio::{
2020
};
2121
use tracing::{debug, error, info, info_span, Instrument};
2222

23-
use crate::{config::ProxyConfig, main_service::Proxy, models::EnteredCounter};
23+
use crate::{
24+
config::ProxyConfig,
25+
main_service::Proxy,
26+
models::EnteredCounter,
27+
pp::{get_inbound_pp_header, DisplayAddr},
28+
};
2429

2530
#[derive(Debug, Clone)]
2631
pub(crate) struct AddressInfo {
@@ -159,11 +164,19 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result<DstInfo> {
159164
pub static NUM_CONNECTIONS: AtomicU64 = AtomicU64::new(0);
160165

161166
async fn handle_connection(
162-
mut inbound: TcpStream,
167+
inbound: TcpStream,
163168
state: Proxy,
164169
dotted_base_domain: &str,
165170
) -> Result<()> {
166171
let timeouts = &state.config.proxy.timeouts;
172+
173+
let pp_timeout = timeouts.pp_header;
174+
let pp_fut = get_inbound_pp_header(inbound, &state.config.proxy);
175+
let (mut inbound, pp_header) = timeout(pp_timeout, pp_fut)
176+
.await
177+
.context("take proxy protocol header timeout")?
178+
.context("failed to take proxy protocol header")?;
179+
info!("client address: {}", DisplayAddr(&pp_header));
167180
let (sni, buffer) = timeout(timeouts.handshake, take_sni(&mut inbound))
168181
.await
169182
.context("take sni timeout")?
@@ -175,14 +188,14 @@ async fn handle_connection(
175188
let dst = parse_destination(&sni, dotted_base_domain)?;
176189
debug!("dst: {dst:?}");
177190
if dst.is_tls {
178-
tls_passthough::proxy_to_app(state, inbound, buffer, &dst.app_id, dst.port).await
191+
tls_passthough::proxy_to_app(state, inbound, pp_header, buffer, &dst).await
179192
} else {
180193
state
181-
.proxy(inbound, buffer, &dst.app_id, dst.port, dst.is_h2)
194+
.proxy(inbound, pp_header, buffer, &dst)
182195
.await
183196
}
184197
} else {
185-
tls_passthough::proxy_with_sni(state, inbound, buffer, &sni).await
198+
tls_passthough::proxy_with_sni(state, inbound, pp_header, buffer, &sni).await
186199
}
187200
}
188201

0 commit comments

Comments
 (0)