Skip to content

Commit 9ef6ac8

Browse files
Add Sec-Websocket-Extension header
1 parent 0a87dfc commit 9ef6ac8

File tree

3 files changed

+377
-24
lines changed

3 files changed

+377
-24
lines changed

src/common/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ pub use self::referer::Referer;
5757
pub use self::referrer_policy::ReferrerPolicy;
5858
pub use self::retry_after::RetryAfter;
5959
pub use self::sec_websocket_accept::SecWebsocketAccept;
60+
pub use self::sec_websocket_extensions::{
61+
WebsocketExtensionParam, SecWebsocketExtensions, WebsocketProtocolExtension,
62+
};
6063
pub use self::sec_websocket_key::SecWebsocketKey;
6164
pub use self::sec_websocket_version::SecWebsocketVersion;
6265
pub use self::server::Server;
@@ -177,6 +180,7 @@ mod referer;
177180
mod referrer_policy;
178181
mod retry_after;
179182
mod sec_websocket_accept;
183+
mod sec_websocket_extensions;
180184
mod sec_websocket_key;
181185
mod sec_websocket_version;
182186
mod server;
Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
use std::{borrow::Cow, fmt::Debug, iter::FromIterator, str::FromStr};
2+
3+
use bytes::BytesMut;
4+
use headers_core::Error;
5+
use http::HeaderValue;
6+
7+
use crate::util::{csv, TryFromValues};
8+
9+
/// The `Sec-Websocket-Extensions` header.
10+
///
11+
/// This header is used in the Websocket handshake, sent by the client to the
12+
/// server and then from the server to the client. It is a proposed and
13+
/// agreed-upon list of websocket protocol extensions to use.
14+
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
15+
pub struct SecWebsocketExtensions(Vec<WebsocketProtocolExtension>);
16+
17+
/// An extension listed in a [`SecWebsocketExtensions`] header.
18+
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
19+
pub struct WebsocketProtocolExtension {
20+
name: Cow<'static, str>,
21+
params: Vec<WebsocketExtensionParam>,
22+
}
23+
24+
/// Named parameter for an extension in a `Sec-Websocket-Extensions` header.
25+
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
26+
pub struct WebsocketExtensionParam {
27+
name: Cow<'static, str>,
28+
value: Option<String>,
29+
}
30+
31+
impl SecWebsocketExtensions {
32+
/// Constructs a new header with the provided extensions.
33+
pub fn new(extensions: impl IntoIterator<Item = WebsocketProtocolExtension>) -> Self {
34+
Self(extensions.into_iter().collect())
35+
}
36+
37+
/// Returns an iterator over the extensions in this header.
38+
pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
39+
self.into_iter()
40+
}
41+
}
42+
43+
impl WebsocketProtocolExtension {
44+
/// Constructs a new extension directive with the given name and parameters.
45+
pub fn new(
46+
name: impl Into<Cow<'static, str>>,
47+
params: impl IntoIterator<Item = WebsocketExtensionParam>,
48+
) -> Self {
49+
Self {
50+
name: name.into(),
51+
params: params.into_iter().collect(),
52+
}
53+
}
54+
55+
/// The name of this extension directive.
56+
pub fn name(&self) -> &str {
57+
&self.name
58+
}
59+
60+
/// Returns an iterator over the parameters for this extension directive.
61+
pub fn params(&self) -> impl Iterator<Item = &WebsocketExtensionParam> {
62+
self.params.iter()
63+
}
64+
}
65+
66+
impl WebsocketExtensionParam {
67+
/// Constructs a new parameter with the given name and optional value.
68+
#[inline]
69+
pub fn new(name: impl Into<Cow<'static, str>>, value: Option<String>) -> Self {
70+
Self {
71+
name: name.into(),
72+
value,
73+
}
74+
}
75+
76+
/// The name of the parameter.
77+
#[inline]
78+
pub fn name(&self) -> &str {
79+
&self.name
80+
}
81+
82+
/// The parameter value, if there is one.
83+
#[inline]
84+
pub fn value(&self) -> Option<&str> {
85+
self.value.as_deref()
86+
}
87+
}
88+
89+
impl crate::Header for SecWebsocketExtensions {
90+
fn name() -> &'static ::http::header::HeaderName {
91+
&::http::header::SEC_WEBSOCKET_EXTENSIONS
92+
}
93+
94+
fn decode<'i, I>(values: &mut I) -> Result<Self, crate::Error>
95+
where
96+
I: Iterator<Item = &'i HeaderValue>,
97+
{
98+
crate::util::TryFromValues::try_from_values(values).map(SecWebsocketExtensions)
99+
}
100+
fn encode<E: Extend<crate::HeaderValue>>(&self, values: &mut E) {
101+
values.extend(std::iter::once(to_header_value(&self.0)));
102+
}
103+
}
104+
105+
impl TryFromValues for Vec<WebsocketProtocolExtension> {
106+
fn try_from_values<'i, I>(values: &mut I) -> Result<Self, Error>
107+
where
108+
Self: Sized,
109+
I: Iterator<Item = &'i HeaderValue>,
110+
{
111+
csv::from_comma_delimited(values)
112+
}
113+
}
114+
115+
impl FromIterator<WebsocketProtocolExtension> for SecWebsocketExtensions {
116+
fn from_iter<T: IntoIterator<Item = WebsocketProtocolExtension>>(iter: T) -> Self {
117+
Self(iter.into_iter().collect())
118+
}
119+
}
120+
121+
impl IntoIterator for SecWebsocketExtensions {
122+
type Item = WebsocketProtocolExtension;
123+
124+
type IntoIter = std::vec::IntoIter<Self::Item>;
125+
126+
fn into_iter(self) -> Self::IntoIter {
127+
self.0.into_iter()
128+
}
129+
}
130+
131+
impl<'a> IntoIterator for &'a SecWebsocketExtensions {
132+
type Item = &'a WebsocketProtocolExtension;
133+
134+
type IntoIter = std::slice::Iter<'a, WebsocketProtocolExtension>;
135+
136+
fn into_iter(self) -> Self::IntoIter {
137+
self.0.iter()
138+
}
139+
}
140+
141+
impl FromStr for WebsocketProtocolExtension {
142+
type Err = Error;
143+
144+
fn from_str(s: &str) -> Result<Self, Self::Err> {
145+
let (name, tail) = s
146+
.split_once(';')
147+
.map(|(n, t)| (n, Some(t)))
148+
.unwrap_or((s, None));
149+
150+
let params = csv::from_delimited(&mut tail.into_iter(), ';')?;
151+
152+
Ok(Self {
153+
name: name.trim().to_owned().into(),
154+
params,
155+
})
156+
}
157+
}
158+
159+
impl std::fmt::Display for WebsocketProtocolExtension {
160+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161+
let Self { name, params } = self;
162+
163+
write!(f, "{name}")?;
164+
for param in params {
165+
f.write_str("; ")?;
166+
write!(f, "{param}")?;
167+
}
168+
169+
Ok(())
170+
}
171+
}
172+
173+
impl FromStr for WebsocketExtensionParam {
174+
type Err = ();
175+
176+
fn from_str(s: &str) -> Result<Self, Self::Err> {
177+
let (name, value) = s
178+
.split_once('=')
179+
.map(|(n, t)| (n, Some(t)))
180+
.unwrap_or((s, None));
181+
182+
let value = value.map(|value| value.trim().to_owned());
183+
184+
Ok(Self {
185+
name: name.trim().to_owned().into(),
186+
value,
187+
})
188+
}
189+
}
190+
191+
impl std::fmt::Display for WebsocketExtensionParam {
192+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193+
let Self { name, value } = self;
194+
195+
write!(f, "{name}")?;
196+
if let Some(value) = value {
197+
write!(f, "={value}")?;
198+
}
199+
Ok(())
200+
}
201+
}
202+
203+
impl WebsocketProtocolExtension {
204+
fn encoded_len(&self) -> usize {
205+
let Self { name, params } = self;
206+
207+
let params_len: usize = params.iter().map(|p| p.encoded_len() + 2).sum();
208+
209+
name.len() + params_len
210+
}
211+
}
212+
213+
impl WebsocketExtensionParam {
214+
fn encoded_len(&self) -> usize {
215+
let Self { name, value } = self;
216+
name.len() + value.as_ref().map(|s| s.len() + 1).unwrap_or_default()
217+
}
218+
219+
fn write_to_buffer(&self, buffer: &mut BytesMut) {
220+
let Self { name, value } = self;
221+
buffer.extend_from_slice(b"; ");
222+
buffer.extend_from_slice(name.as_bytes());
223+
224+
if let Some(value) = value {
225+
buffer.extend_from_slice(b"=");
226+
buffer.extend_from_slice(value.as_bytes());
227+
}
228+
}
229+
}
230+
231+
fn to_header_value(extensions: &[WebsocketProtocolExtension]) -> HeaderValue {
232+
let mut buffer = BytesMut::with_capacity(encoded_len(extensions));
233+
234+
for extension in extensions {
235+
if !buffer.is_empty() {
236+
buffer.extend_from_slice(b", ");
237+
}
238+
239+
let WebsocketProtocolExtension { name, params } = extension;
240+
buffer.extend_from_slice(name.as_bytes());
241+
242+
for param in params {
243+
param.write_to_buffer(&mut buffer);
244+
}
245+
}
246+
247+
HeaderValue::from_maybe_shared(buffer).expect("valid construction")
248+
}
249+
250+
fn encoded_len(extensions: &[WebsocketProtocolExtension]) -> usize {
251+
let all_encoded_len: usize = extensions
252+
.iter()
253+
.map(WebsocketProtocolExtension::encoded_len)
254+
.sum();
255+
let all_separators_len = extensions
256+
.len()
257+
.checked_sub(1)
258+
.map(|num_separators| num_separators * 2)
259+
.unwrap_or_default();
260+
all_encoded_len + all_separators_len
261+
}
262+
263+
#[cfg(test)]
264+
mod tests {
265+
use std::convert::TryInto;
266+
267+
use crate::Header;
268+
269+
use super::super::{test_decode, test_encode};
270+
use super::*;
271+
272+
#[test]
273+
fn parse_separate_headers() {
274+
// From https://tools.ietf.org/html/rfc6455#section-9.1
275+
let extensions =
276+
test_decode::<SecWebsocketExtensions>(&["foo", "bar; baz=2"]).expect("valid");
277+
278+
assert_eq!(
279+
extensions,
280+
SecWebsocketExtensions(vec![
281+
WebsocketProtocolExtension {
282+
name: "foo".into(),
283+
params: vec![],
284+
},
285+
WebsocketProtocolExtension {
286+
name: "bar".into(),
287+
params: vec![WebsocketExtensionParam {
288+
name: "baz".into(),
289+
value: Some("2".to_owned())
290+
}],
291+
}
292+
])
293+
);
294+
}
295+
296+
#[test]
297+
fn round_trip_complex() {
298+
let extensions = test_decode::<SecWebsocketExtensions>(&[
299+
"deflate-stream",
300+
"mux; max-channels=4; flow-control, deflate-stream",
301+
"private-extension",
302+
])
303+
.expect("valid");
304+
305+
let headers = test_encode(extensions);
306+
assert_eq!(
307+
headers["sec-websocket-extensions"],
308+
"deflate-stream, mux; max-channels=4; flow-control, deflate-stream, private-extension"
309+
);
310+
}
311+
312+
#[test]
313+
fn to_header_value_exact() {
314+
// This isn't a required property for correctness but if the length
315+
// precomputation is wrong we'll over- or under-allocate during
316+
// conversion.
317+
let cases = [
318+
SecWebsocketExtensions::new([
319+
WebsocketProtocolExtension::from_str("extension-name").unwrap(),
320+
WebsocketProtocolExtension::from_str("with-params; a=5; b=8").unwrap(),
321+
]),
322+
SecWebsocketExtensions::new([]),
323+
SecWebsocketExtensions::new([
324+
WebsocketProtocolExtension::from_str("duplicate-name").unwrap(),
325+
WebsocketProtocolExtension::from_str("duplicate-name").unwrap(),
326+
WebsocketProtocolExtension::from_str("duplicate-name").unwrap(),
327+
]),
328+
];
329+
330+
for case in cases {
331+
let mut values = Vec::new();
332+
case.encode(&mut values);
333+
let [value] = values.try_into().unwrap();
334+
335+
assert_eq!(value.len(), encoded_len(&case.0));
336+
}
337+
}
338+
}

0 commit comments

Comments
 (0)