Skip to content

Adds typing to context where important for new multi-value headers, and streamlines header setting & getting #421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
8606591
Make context a dataclass with typing
bitterpanda63 Jul 2, 2025
5e80426
Revert "Make context a dataclass with typing"
bitterpanda63 Jul 2, 2025
340efba
Context: add typing
bitterpanda63 Jul 2, 2025
bc2b134
Create get_header function on context
bitterpanda63 Jul 2, 2025
a784679
asgi: now returns ASGIContext
bitterpanda63 Jul 2, 2025
f9fe2ec
WSGI: now returns WSGIContext
bitterpanda63 Jul 2, 2025
e8d972e
Add typing to parse_cookies
bitterpanda63 Jul 2, 2025
516364e
Refactors context initialization for WSGI/ASGI
bitterpanda63 Jul 2, 2025
1834a2b
Fix flask test cases
bitterpanda63 Jul 2, 2025
94ff879
get_ip_from_request: make capable of using headers
bitterpanda63 Jul 2, 2025
cd9f79a
Create new headers store helper class
bitterpanda63 Jul 2, 2025
0331e85
ASGI use new headers helper class
bitterpanda63 Jul 2, 2025
8b6e0c1
WSGI use new headers helper class
bitterpanda63 Jul 2, 2025
6f30d95
context: use new headers helper class
bitterpanda63 Jul 2, 2025
3a62e36
get_ip_from_request use new Headers class
bitterpanda63 Jul 2, 2025
99129db
asgi also return headers type
bitterpanda63 Jul 2, 2025
01e6c55
Fix test cases for get_ip_form_request_test.py
bitterpanda63 Jul 2, 2025
aa31af9
rename get_ip_form_request to get_ip_from_request
bitterpanda63 Jul 2, 2025
57662c2
Fix import in headers.py
bitterpanda63 Jul 2, 2025
1e68e55
request_handler test cases now use new Headers class
bitterpanda63 Jul 2, 2025
5ec08e4
fix on_detected_attack test cases by storing value in array
bitterpanda63 Jul 2, 2025
d375bc2
Merge branch 'main' into create-get-header-extract-wsgi-multiple-head…
bitterpanda63 Jul 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ def test_on_detected_attack_request_data_and_attack_data(
assert request_data["ipAddress"] == "198.51.100.23"
assert request_data["body"] == 123
assert request_data["headers"] == {
"CONTENT_TYPE": "application/json",
"USER_AGENT": "Mozilla/5.0",
"COOKIE": "sessionId=abc123xyz456;",
"HEADER_1": "header 1 value",
"HEADER_2": "Header 2 value",
"HOST": "localhost:8080",
"CONTENT_TYPE": ["application/json"],
"USER_AGENT": ["Mozilla/5.0"],
"COOKIE": ["sessionId=abc123xyz456;"],
"HEADER_1": ["header 1 value"],
"HEADER_2": ["Header 2 value"],
"HOST": ["localhost:8080"],
}
assert request_data["source"] == "django"
assert request_data["route"] == "/hello"
Expand Down
60 changes: 35 additions & 25 deletions aikido_zen/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,29 @@

import contextvars
import json
from json import JSONDecodeError
from time import sleep
from urllib.parse import parse_qs
from typing import Optional, Dict, List

from aikido_zen.helpers.build_route_from_url import build_route_from_url
from aikido_zen.helpers.get_subdomains_from_url import get_subdomains_from_url
from aikido_zen.helpers.logging import logger
from .wsgi import set_wsgi_attributes_on_context
from .asgi import set_asgi_attributes_on_context
from .wsgi import parse_wsgi_environ, WSGIContext
from .asgi import parse_asgi_scope, ASGIContext
from .extract_route_params import extract_route_params
from ..helpers.headers import Headers

UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml", "route_params"]
current_context = contextvars.ContextVar("current_context", default=None)
current_context = contextvars.ContextVar[Optional["Context"]](
"current_context", default=None
)

WSGI_SOURCES = ["django", "flask"]
ASGI_SOURCES = ["quart", "django_async", "starlette"]


def get_current_context():
"""Returns the current context"""
def get_current_context() -> Optional["Context"]:
try:
return current_context.get()
except Exception:
except LookupError:

Check warning on line 29 in aikido_zen/context/__init__.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/context/__init__.py#L29

Added line #L29 was not covered by tests
return None


Expand All @@ -41,30 +41,44 @@
logger.debug("Creating Context instance based on dict object.")
self.__dict__.update(context_obj)
return
# Define emtpy variables/Properties :

self.source = source
self.user = None
self.method = None
self.remote_address = None
self.url = None
self.parsed_userinput = {}
self.xml = {}
self.outgoing_req_redirects = []
self.headers: Headers = Headers()
self.query: Dict[str, List[str]] = dict()
self.cookies: Dict[str, List[str]] = dict()
self.executed_middleware = False
self.set_body(body)

# Parse WSGI/ASGI/... request :
self.cookies = self.method = self.remote_address = self.query = self.headers = (
self.url
) = None
if source in WSGI_SOURCES:
set_wsgi_attributes_on_context(self, req)
wsgi_context: WSGIContext = parse_wsgi_environ(req)
self.method = wsgi_context.method
self.remote_address = wsgi_context.remote_address
self.url = wsgi_context.url
self.headers = wsgi_context.headers
self.query = wsgi_context.query
self.cookies = wsgi_context.cookies
elif source in ASGI_SOURCES:
set_asgi_attributes_on_context(self, req)
asgi_context: ASGIContext = parse_asgi_scope(req)
self.method = asgi_context.method
self.remote_address = asgi_context.remote_address
self.url = asgi_context.url
self.headers = asgi_context.headers
self.query = asgi_context.query
self.cookies = asgi_context.cookies
else:
raise Exception("Unsupported source: " + source)

Check warning on line 76 in aikido_zen/context/__init__.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/context/__init__.py#L76

Added line #L76 was not covered by tests

# Define variables using parsed request :
self.route = build_route_from_url(self.url)
self.route_params = extract_route_params(self.url)
self.subdomains = get_subdomains_from_url(self.url)

self.executed_middleware = False

def __reduce__(self):
return (
self.__class__,
Expand Down Expand Up @@ -127,9 +141,5 @@
"url": self.url,
}

def get_user_agent(self):
if "USER_AGENT" not in self.headers:
return None
if isinstance(self.headers["USER_AGENT"], list):
return self.headers["USER_AGENT"][-1]
return self.headers["USER_AGENT"]
def get_user_agent(self) -> Optional[str]:
return self.headers.get_header("USER_AGENT")
47 changes: 29 additions & 18 deletions aikido_zen/context/asgi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
"""Exports set_asgi_attributes_on_context"""

from dataclasses import dataclass
from typing import Dict, List
from urllib.parse import parse_qs
from aikido_zen.helpers.get_ip_from_request import get_ip_from_request
from aikido_zen.helpers.logging import logger
from ..parse_cookies import parse_cookies
from .normalize_asgi_headers import normalize_asgi_headers
from .extract_asgi_headers import extract_asgi_headers
from .build_url_from_asgi import build_url_from_asgi
from ...helpers.headers import Headers


@dataclass
class ASGIContext:
method: str
headers: Headers
cookies: dict
url: str
query: dict
remote_address: str


def set_asgi_attributes_on_context(context, scope):
def parse_asgi_scope(scope) -> ASGIContext:
"""
This extracts ASGI Scope attributes, described in :
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
logger.debug("Setting ASGI attributes")
context.method = scope["method"]
context.headers = normalize_asgi_headers(scope["headers"])
headers = extract_asgi_headers(scope["headers"])

if "COOKIE" in context.headers and len(context.headers["COOKIE"]) > 0:
# Right now just use the first Cookie header, will change later to use
# framework definition for cookies.
context.cookies = parse_cookies(context.headers["COOKIE"][0])
else:
context.cookies = {}

context.url = build_url_from_asgi(scope)
context.query = parse_qs(scope["query_string"].decode("utf-8"))
cookies = {}
if "COOKIE" in headers and headers["COOKIE"]:
cookies = parse_cookies(headers["COOKIE"][-1])

raw_ip = scope["client"][0] if scope["client"] else ""
context.remote_address = get_ip_from_request(raw_ip, context.headers)
remote_address = get_ip_from_request(raw_ip, headers)

return ASGIContext(
method=scope["method"],
headers=headers,
cookies=cookies,
url=build_url_from_asgi(scope),
query=parse_qs(scope["query_string"].decode("utf-8")),
remote_address=remote_address,
)
10 changes: 10 additions & 0 deletions aikido_zen/context/asgi/extract_asgi_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Mainly exports extract_asgi_headers"""

from aikido_zen.helpers.headers import Headers


def extract_asgi_headers(headers) -> Headers:
result = Headers()
for k, v in headers:
result.store_header(k.decode("utf-8"), v.decode("utf-8"))
return result
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
import pytest
from .normalize_asgi_headers import normalize_asgi_headers
from .extract_asgi_headers import extract_asgi_headers


def test_normalize_asgi_headers_basic():
def test_extract_asgi_headers_basic():
headers = [
(b"content-type", b"text/html"),
(b"accept-encoding", b"gzip, deflate"),
]
expected = {"CONTENT_TYPE": ["text/html"], "ACCEPT_ENCODING": ["gzip, deflate"]}
assert normalize_asgi_headers(headers) == expected
assert extract_asgi_headers(headers) == expected


def test_normalize_asgi_headers_empty():
def test_extract_asgi_headers_empty():
headers = []
expected = {}
assert normalize_asgi_headers(headers) == expected
assert extract_asgi_headers(headers) == expected


def test_normalize_asgi_headers_with_special_characters():
def test_extract_asgi_headers_with_special_characters():
headers = [
(b"content-type", b"text/html"),
(b"x-custom-header", b"some_value"),
Expand All @@ -28,19 +28,19 @@ def test_normalize_asgi_headers_with_special_characters():
"X_CUSTOM_HEADER": ["some_value"],
"ACCEPT_ENCODING": ["gzip, deflate"],
}
assert normalize_asgi_headers(headers) == expected
assert extract_asgi_headers(headers) == expected


def test_normalize_asgi_headers_with_dashes():
def test_extract_asgi_headers_with_dashes():
headers = [
(b"X-Forwarded-For", b"192.168.1.1"),
(b"X-Request-ID", b"abc123"),
]
expected = {"X_FORWARDED_FOR": ["192.168.1.1"], "X_REQUEST_ID": ["abc123"]}
assert normalize_asgi_headers(headers) == expected
assert extract_asgi_headers(headers) == expected


def test_normalize_asgi_headers_case_insensitivity():
def test_extract_asgi_headers_case_insensitivity():
headers = [
(b"Content-Type", b"text/html"),
(b"ACCEPT-ENCODING", b"gzip, deflate"),
Expand All @@ -50,19 +50,19 @@ def test_normalize_asgi_headers_case_insensitivity():
"CONTENT_TYPE": ["text/html"],
"ACCEPT_ENCODING": ["gzip, deflate", "json"],
}
assert normalize_asgi_headers(headers) == expected
assert extract_asgi_headers(headers) == expected


def test_normalize_asgi_headers_unicode():
def test_extract_asgi_headers_unicode():
headers = [
(b"content-type", b"text/html"),
(b"custom-header", b"value"),
]
expected = {"CONTENT_TYPE": ["text/html"], "CUSTOM_HEADER": ["value"]}
assert normalize_asgi_headers(headers) == expected
assert extract_asgi_headers(headers) == expected


def test_normalize_asgi_headers_mixed_case_and_dashes():
def test_extract_asgi_headers_mixed_case_and_dashes():
headers = [
(b"Content-Type", b"text/html"),
(b"X-Custom-Header", b"some_value"),
Expand All @@ -73,10 +73,10 @@ def test_normalize_asgi_headers_mixed_case_and_dashes():
"X_CUSTOM_HEADER": ["some_value"],
"ACCEPT_ENCODING": ["gzip, deflate"],
}
assert normalize_asgi_headers(headers) == expected
assert extract_asgi_headers(headers) == expected


def test_normalize_asgi_headers_non_ascii():
def test_extract_asgi_headers_non_ascii():
headers = [
(b"content-type", b"text/html"),
(b"custom-header", b"value"),
Expand All @@ -87,13 +87,13 @@ def test_normalize_asgi_headers_non_ascii():
"CUSTOM_HEADER": ["value"],
"X_HEADER_WITH_EMOJI": ["test"],
}
assert normalize_asgi_headers(headers) == expected
assert extract_asgi_headers(headers) == expected


def test_normalize_asgi_headers_large_input():
def test_extract_asgi_headers_large_input():
headers = [
(f"header-{i}".encode("utf-8"), f"value-{i}".encode("utf-8"))
for i in range(1000)
]
expected = {f"HEADER_{i}": [f"value-{i}"] for i in range(1000)}
assert normalize_asgi_headers(headers) == expected
assert extract_asgi_headers(headers) == expected
Loading
Loading