Skip to content

Commit 3dce1d3

Browse files
committed
http2: support trailer headers
1 parent 9820975 commit 3dce1d3

File tree

4 files changed

+402
-18
lines changed

4 files changed

+402
-18
lines changed

httpcore/_async/http2.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ def __init__(
7171
h2.events.ResponseReceived
7272
| h2.events.DataReceived
7373
| h2.events.StreamEnded
74-
| h2.events.StreamReset,
74+
| h2.events.StreamReset
75+
| h2.events.TrailersReceived,
7576
],
7677
] = {}
7778

79+
# Mapping from stream ID to trailing headers
80+
self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {}
81+
7882
# Connection terminated events are stored as state since
7983
# we need to handle them for all streams.
8084
self._connection_terminated: h2.events.ConnectionTerminated | None = None
@@ -152,15 +156,22 @@ async def handle_async_request(self, request: Request) -> Response:
152156
)
153157
trace.return_value = (status, headers)
154158

159+
extensions = {
160+
"http_version": b"HTTP/2",
161+
"network_stream": self._network_stream,
162+
"stream_id": stream_id,
163+
}
164+
155165
return Response(
156166
status=status,
157167
headers=headers,
158-
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
159-
extensions={
160-
"http_version": b"HTTP/2",
161-
"network_stream": self._network_stream,
162-
"stream_id": stream_id,
163-
},
168+
content=HTTP2ConnectionByteStream(
169+
connection=self,
170+
request=request,
171+
stream_id=stream_id,
172+
extensions=extensions,
173+
),
174+
extensions=extensions,
164175
)
165176
except BaseException as exc: # noqa: PIE786
166177
with AsyncShieldCancellation():
@@ -326,7 +337,12 @@ async def _receive_response_body(
326337

327338
async def _receive_stream_event(
328339
self, request: Request, stream_id: int
329-
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
340+
) -> (
341+
h2.events.ResponseReceived
342+
| h2.events.DataReceived
343+
| h2.events.StreamEnded
344+
| h2.events.TrailersReceived
345+
):
330346
"""
331347
Return the next available event for a given stream ID.
332348
@@ -337,6 +353,13 @@ async def _receive_stream_event(
337353
event = self._events[stream_id].pop(0)
338354
if isinstance(event, h2.events.StreamReset):
339355
raise RemoteProtocolError(event)
356+
elif isinstance(event, h2.events.TrailersReceived):
357+
if event.stream_id in self._events and event.headers is not None:
358+
self._trailing_headers[event.stream_id] = []
359+
for k, v in event.headers:
360+
if not k.startswith(b":"):
361+
self._trailing_headers[event.stream_id].append((k, v))
362+
340363
return event
341364

342365
async def _receive_events(
@@ -377,6 +400,7 @@ async def _receive_events(
377400
h2.events.DataReceived,
378401
h2.events.StreamEnded,
379402
h2.events.StreamReset,
403+
h2.events.TrailersReceived,
380404
),
381405
):
382406
if event.stream_id in self._events:
@@ -409,6 +433,8 @@ async def _receive_remote_settings_change(
409433
async def _response_closed(self, stream_id: int) -> None:
410434
await self._max_streams_semaphore.release()
411435
del self._events[stream_id]
436+
if stream_id in self._trailing_headers:
437+
del self._trailing_headers[stream_id]
412438
async with self._state_lock:
413439
if self._connection_terminated and not self._events:
414440
await self.aclose()
@@ -561,12 +587,17 @@ async def __aexit__(
561587

562588
class HTTP2ConnectionByteStream:
563589
def __init__(
564-
self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
590+
self,
591+
connection: AsyncHTTP2Connection,
592+
request: Request,
593+
stream_id: int,
594+
extensions: typing.MutableMapping[str, typing.Any],
565595
) -> None:
566596
self._connection = connection
567597
self._request = request
568598
self._stream_id = stream_id
569599
self._closed = False
600+
self._extensions = extensions
570601

571602
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
572603
kwargs = {"request": self._request, "stream_id": self._stream_id}
@@ -576,6 +607,11 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
576607
request=self._request, stream_id=self._stream_id
577608
):
578609
yield chunk
610+
611+
if self._stream_id in self._connection._trailing_headers:
612+
self._extensions["trailing_headers"] = (
613+
self._connection._trailing_headers[self._stream_id]
614+
)
579615
except BaseException as exc:
580616
# If we get an exception while streaming the response,
581617
# we want to close the response (and possibly the connection)

httpcore/_sync/http2.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ def __init__(
7171
h2.events.ResponseReceived
7272
| h2.events.DataReceived
7373
| h2.events.StreamEnded
74-
| h2.events.StreamReset,
74+
| h2.events.StreamReset
75+
| h2.events.TrailersReceived,
7576
],
7677
] = {}
7778

79+
# Mapping from stream ID to trailing headers
80+
self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {}
81+
7882
# Connection terminated events are stored as state since
7983
# we need to handle them for all streams.
8084
self._connection_terminated: h2.events.ConnectionTerminated | None = None
@@ -152,15 +156,22 @@ def handle_request(self, request: Request) -> Response:
152156
)
153157
trace.return_value = (status, headers)
154158

159+
extensions = {
160+
"http_version": b"HTTP/2",
161+
"network_stream": self._network_stream,
162+
"stream_id": stream_id,
163+
}
164+
155165
return Response(
156166
status=status,
157167
headers=headers,
158-
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
159-
extensions={
160-
"http_version": b"HTTP/2",
161-
"network_stream": self._network_stream,
162-
"stream_id": stream_id,
163-
},
168+
content=HTTP2ConnectionByteStream(
169+
connection=self,
170+
request=request,
171+
stream_id=stream_id,
172+
extensions=extensions,
173+
),
174+
extensions=extensions,
164175
)
165176
except BaseException as exc: # noqa: PIE786
166177
with ShieldCancellation():
@@ -326,7 +337,12 @@ def _receive_response_body(
326337

327338
def _receive_stream_event(
328339
self, request: Request, stream_id: int
329-
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
340+
) -> (
341+
h2.events.ResponseReceived
342+
| h2.events.DataReceived
343+
| h2.events.StreamEnded
344+
| h2.events.TrailersReceived
345+
):
330346
"""
331347
Return the next available event for a given stream ID.
332348
@@ -337,6 +353,13 @@ def _receive_stream_event(
337353
event = self._events[stream_id].pop(0)
338354
if isinstance(event, h2.events.StreamReset):
339355
raise RemoteProtocolError(event)
356+
elif isinstance(event, h2.events.TrailersReceived):
357+
if event.stream_id in self._events and event.headers is not None:
358+
self._trailing_headers[event.stream_id] = []
359+
for k, v in event.headers:
360+
if not k.startswith(b":"):
361+
self._trailing_headers[event.stream_id].append((k, v))
362+
340363
return event
341364

342365
def _receive_events(
@@ -377,6 +400,7 @@ def _receive_events(
377400
h2.events.DataReceived,
378401
h2.events.StreamEnded,
379402
h2.events.StreamReset,
403+
h2.events.TrailersReceived,
380404
),
381405
):
382406
if event.stream_id in self._events:
@@ -409,6 +433,8 @@ def _receive_remote_settings_change(
409433
def _response_closed(self, stream_id: int) -> None:
410434
self._max_streams_semaphore.release()
411435
del self._events[stream_id]
436+
if stream_id in self._trailing_headers:
437+
del self._trailing_headers[stream_id]
412438
with self._state_lock:
413439
if self._connection_terminated and not self._events:
414440
self.close()
@@ -561,12 +587,17 @@ def __exit__(
561587

562588
class HTTP2ConnectionByteStream:
563589
def __init__(
564-
self, connection: HTTP2Connection, request: Request, stream_id: int
590+
self,
591+
connection: HTTP2Connection,
592+
request: Request,
593+
stream_id: int,
594+
extensions: typing.MutableMapping[str, typing.Any],
565595
) -> None:
566596
self._connection = connection
567597
self._request = request
568598
self._stream_id = stream_id
569599
self._closed = False
600+
self._extensions = extensions
570601

571602
def __iter__(self) -> typing.Iterator[bytes]:
572603
kwargs = {"request": self._request, "stream_id": self._stream_id}
@@ -576,6 +607,11 @@ def __iter__(self) -> typing.Iterator[bytes]:
576607
request=self._request, stream_id=self._stream_id
577608
):
578609
yield chunk
610+
611+
if self._stream_id in self._connection._trailing_headers:
612+
self._extensions["trailing_headers"] = (
613+
self._connection._trailing_headers[self._stream_id]
614+
)
579615
except BaseException as exc:
580616
# If we get an exception while streaming the response,
581617
# we want to close the response (and possibly the connection)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import hpack
2+
import hyperframe.frame
3+
import pytest
4+
5+
import httpcore
6+
7+
8+
@pytest.mark.anyio
9+
async def test_http2_connection_with_trailing_headers():
10+
"""
11+
Test that trailing headers are correctly received and processed.
12+
"""
13+
origin = httpcore.Origin(b"https", b"example.com", 443)
14+
stream = httpcore.AsyncMockStream(
15+
[
16+
hyperframe.frame.SettingsFrame().serialize(),
17+
hyperframe.frame.HeadersFrame(
18+
stream_id=1,
19+
data=hpack.Encoder().encode(
20+
[
21+
(b":status", b"200"),
22+
(b"content-type", b"plain/text"),
23+
]
24+
),
25+
flags=["END_HEADERS"],
26+
).serialize(),
27+
hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(),
28+
# Send trailing headers
29+
hyperframe.frame.HeadersFrame(
30+
stream_id=1,
31+
data=hpack.Encoder().encode(
32+
[
33+
(b"x-trailer-1", b"trailer-value-1"),
34+
(b"x-trailer-2", b"trailer-value-2"),
35+
]
36+
),
37+
flags=["END_HEADERS", "END_STREAM"],
38+
).serialize(),
39+
]
40+
)
41+
async with httpcore.AsyncHTTP2Connection(
42+
origin=origin, stream=stream, keepalive_expiry=5.0
43+
) as conn:
44+
response = await conn.request("GET", "https://example.com/")
45+
assert response.status == 200
46+
assert response.content == b"Hello, world!"
47+
48+
# Check that trailing headers are included in extensions
49+
assert "trailing_headers" in response.extensions
50+
assert response.extensions["trailing_headers"] == [
51+
(b"x-trailer-1", b"trailer-value-1"),
52+
(b"x-trailer-2", b"trailer-value-2"),
53+
]
54+
55+
56+
@pytest.mark.anyio
57+
async def test_http2_connection_with_body_and_trailing_headers():
58+
"""
59+
Test that trailing headers are correctly received and processed
60+
when reading the response body in chunks.
61+
"""
62+
origin = httpcore.Origin(b"https", b"example.com", 443)
63+
stream = httpcore.AsyncMockStream(
64+
[
65+
hyperframe.frame.SettingsFrame().serialize(),
66+
hyperframe.frame.HeadersFrame(
67+
stream_id=1,
68+
data=hpack.Encoder().encode(
69+
[
70+
(b":status", b"200"),
71+
(b"content-type", b"plain/text"),
72+
]
73+
),
74+
flags=["END_HEADERS"],
75+
).serialize(),
76+
hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, ").serialize(),
77+
hyperframe.frame.DataFrame(stream_id=1, data=b"world!").serialize(),
78+
# Send trailing headers
79+
hyperframe.frame.HeadersFrame(
80+
stream_id=1,
81+
data=hpack.Encoder().encode(
82+
[
83+
(b"x-trailer-1", b"trailer-value-1"),
84+
(b"x-trailer-2", b"trailer-value-2"),
85+
]
86+
),
87+
flags=["END_HEADERS", "END_STREAM"],
88+
).serialize(),
89+
]
90+
)
91+
92+
async with httpcore.AsyncHTTP2Connection(
93+
origin=origin, stream=stream, keepalive_expiry=5.0
94+
) as conn:
95+
async with conn.stream("GET", "https://example.com/") as response:
96+
content = b""
97+
async for chunk in response.aiter_stream():
98+
content += chunk
99+
100+
assert response.status == 200
101+
assert content == b"Hello, world!"
102+
103+
# Check that trailing headers are included in extensions
104+
assert "trailing_headers" in response.extensions
105+
assert response.extensions["trailing_headers"] == [
106+
(b"x-trailer-1", b"trailer-value-1"),
107+
(b"x-trailer-2", b"trailer-value-2"),
108+
]
109+
110+
111+
@pytest.mark.anyio
112+
async def test_http2_connection_with_trailing_headers_pseudo_removed():
113+
"""
114+
Test that pseudo-headers in trailing headers are correctly filtered out.
115+
"""
116+
origin = httpcore.Origin(b"https", b"example.com", 443)
117+
stream = httpcore.AsyncMockStream(
118+
[
119+
hyperframe.frame.SettingsFrame().serialize(),
120+
hyperframe.frame.HeadersFrame(
121+
stream_id=1,
122+
data=hpack.Encoder().encode(
123+
[
124+
(b":status", b"200"),
125+
(b"content-type", b"plain/text"),
126+
]
127+
),
128+
flags=["END_HEADERS"],
129+
).serialize(),
130+
hyperframe.frame.DataFrame(stream_id=1, data=b"Hello, world!").serialize(),
131+
# Send trailing headers with a pseudo-header which should be filtered out
132+
hyperframe.frame.HeadersFrame(
133+
stream_id=1,
134+
data=hpack.Encoder().encode(
135+
[
136+
(b":pseudo", b"should-be-filtered"),
137+
(b"x-trailer", b"trailer-value"),
138+
]
139+
),
140+
flags=["END_HEADERS", "END_STREAM"],
141+
).serialize(),
142+
]
143+
)
144+
async with httpcore.AsyncHTTP2Connection(
145+
origin=origin, stream=stream, keepalive_expiry=5.0
146+
) as conn:
147+
response = await conn.request("GET", "https://example.com/")
148+
assert response.status == 200
149+
assert response.content == b"Hello, world!"
150+
151+
# Check that trailing headers are included in extensions but pseudo-headers are filtered
152+
assert "trailing_headers" in response.extensions
153+
assert len(response.extensions["trailing_headers"]) == 1
154+
assert response.extensions["trailing_headers"] == [
155+
(b"x-trailer", b"trailer-value"),
156+
]

0 commit comments

Comments
 (0)