From 86065917cf1e30113451bf7a39cadac482823353 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 13:32:42 +0200 Subject: [PATCH 01/21] Make context a dataclass with typing --- aikido_zen/context/__init__.py | 104 +++++++++------------------------ 1 file changed, 29 insertions(+), 75 deletions(-) diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index af541ecd..20d4340b 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -1,13 +1,7 @@ -""" -Provides all the functionality for contexts -""" - +from dataclasses import dataclass, field import contextvars import json -from json import JSONDecodeError -from time import sleep -from urllib.parse import parse_qs - +from typing import Any, Dict, List, Optional, Union from aikido_zen.helpers.build_route_from_url import build_route_from_url from aikido_zen.helpers.get_subdomains_from_url import get_subdomains_from_url from aikido_zen.helpers.logging import logger @@ -18,10 +12,6 @@ UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml", "route_params"] current_context = contextvars.ContextVar("current_context", default=None) -WSGI_SOURCES = ["django", "flask"] -ASGI_SOURCES = ["quart", "django_async", "starlette"] - - def get_current_context(): """Returns the current context""" try: @@ -30,74 +20,42 @@ def get_current_context(): return None -class Context: - """ - A context object, it stores everything that is important - for vulnerability detection - """ +@dataclass +class AikidoContext: + method: str + url: str + remote_address: str + source: Optional[str] = None + user: Optional[Any] = None + executed_middleware: bool = False - def __init__(self, context_obj=None, body=None, req=None, source=None): - if context_obj: - logger.debug("Creating Context instance based on dict object.") - self.__dict__.update(context_obj) - return - # Define emtpy variables/Properties : - self.source = source - self.user = None - self.parsed_userinput = {} - self.xml = {} - self.outgoing_req_redirects = [] - self.set_body(body) + body: Optional[Any] = None + cookies: Dict[str, List[str]] = field(default_factory=dict) + query: Dict[str, List[str]] = field(default_factory=dict) + headers: Dict[str, List[str]] = field(default_factory=dict) + xml: Dict[str, Any] = field(default_factory=dict) - # Parse WSGI/ASGI/... request : - self.cookies = self.method = self.remote_address = self.query = self.headers = ( - self.url - ) = None - if source in WSGI_SOURCES: - set_wsgi_attributes_on_context(self, req) - elif source in ASGI_SOURCES: - set_asgi_attributes_on_context(self, req) + parsed_userinput: Dict[str, Any] = field(default_factory=dict) + outgoing_req_redirects: List[Any] = field(default_factory=list) - # Define variables using parsed request : + route: Optional[str] = None + route_params: Dict[str, Any] = field(default_factory=dict) + subdomains: List[str] = field(default_factory=list) + + def __post_init__(self): self.route = build_route_from_url(self.url) self.route_params = extract_route_params(self.url) self.subdomains = get_subdomains_from_url(self.url) - self.executed_middleware = False - - def __reduce__(self): - return ( - self.__class__, - ( - { - "method": self.method, - "remote_address": self.remote_address, - "url": self.url, - "body": self.body, - "headers": self.headers, - "query": self.query, - "cookies": self.cookies, - "source": self.source, - "route": self.route, - "subdomains": self.subdomains, - "user": self.user, - "xml": self.xml, - "outgoing_req_redirects": self.outgoing_req_redirects, - "executed_middleware": self.executed_middleware, - "route_params": self.route_params, - }, - None, - None, - ), - ) + def get_header(self, key: str) -> Optional[str]: + if key not in self.headers or not self.headers[key]: + return None + return self.headers[key][-1] def set_as_current_context(self): - """ - Set the current context - """ current_context.set(self) - def set_body(self, body): + def set_body(self, body: Optional[Any]): try: self.set_body_internal(body) except Exception as e: @@ -127,9 +85,5 @@ def get_route_metadata(self): "url": self.url, } - def get_user_agent(self): - if "USER_AGENT" not in self.headers: - return None - if isinstance(self.headers["USER_AGENT"], list): - return self.headers["USER_AGENT"][-1] - return self.headers["USER_AGENT"] + def get_user_agent(self) -> Optional[str]: + return self.get_header("USER_AGENT") From 5e8042668dcfa4b2c204fbf54e11632c0cd39c8f Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 13:56:06 +0200 Subject: [PATCH 02/21] Revert "Make context a dataclass with typing" This reverts commit 86065917cf1e30113451bf7a39cadac482823353. --- aikido_zen/context/__init__.py | 104 ++++++++++++++++++++++++--------- 1 file changed, 75 insertions(+), 29 deletions(-) diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index 20d4340b..af541ecd 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -1,7 +1,13 @@ -from dataclasses import dataclass, field +""" +Provides all the functionality for contexts +""" + import contextvars import json -from typing import Any, Dict, List, Optional, Union +from json import JSONDecodeError +from time import sleep +from urllib.parse import parse_qs + from aikido_zen.helpers.build_route_from_url import build_route_from_url from aikido_zen.helpers.get_subdomains_from_url import get_subdomains_from_url from aikido_zen.helpers.logging import logger @@ -12,6 +18,10 @@ UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml", "route_params"] current_context = contextvars.ContextVar("current_context", default=None) +WSGI_SOURCES = ["django", "flask"] +ASGI_SOURCES = ["quart", "django_async", "starlette"] + + def get_current_context(): """Returns the current context""" try: @@ -20,42 +30,74 @@ def get_current_context(): return None -@dataclass -class AikidoContext: - method: str - url: str - remote_address: str - source: Optional[str] = None - user: Optional[Any] = None - executed_middleware: bool = False - - body: Optional[Any] = None - cookies: Dict[str, List[str]] = field(default_factory=dict) - query: Dict[str, List[str]] = field(default_factory=dict) - headers: Dict[str, List[str]] = field(default_factory=dict) - xml: Dict[str, Any] = field(default_factory=dict) +class Context: + """ + A context object, it stores everything that is important + for vulnerability detection + """ - parsed_userinput: Dict[str, Any] = field(default_factory=dict) - outgoing_req_redirects: List[Any] = field(default_factory=list) + def __init__(self, context_obj=None, body=None, req=None, source=None): + if context_obj: + logger.debug("Creating Context instance based on dict object.") + self.__dict__.update(context_obj) + return + # Define emtpy variables/Properties : + self.source = source + self.user = None + self.parsed_userinput = {} + self.xml = {} + self.outgoing_req_redirects = [] + self.set_body(body) - route: Optional[str] = None - route_params: Dict[str, Any] = field(default_factory=dict) - subdomains: List[str] = field(default_factory=list) + # Parse WSGI/ASGI/... request : + self.cookies = self.method = self.remote_address = self.query = self.headers = ( + self.url + ) = None + if source in WSGI_SOURCES: + set_wsgi_attributes_on_context(self, req) + elif source in ASGI_SOURCES: + set_asgi_attributes_on_context(self, req) - def __post_init__(self): + # Define variables using parsed request : self.route = build_route_from_url(self.url) self.route_params = extract_route_params(self.url) self.subdomains = get_subdomains_from_url(self.url) - def get_header(self, key: str) -> Optional[str]: - if key not in self.headers or not self.headers[key]: - return None - return self.headers[key][-1] + self.executed_middleware = False + + def __reduce__(self): + return ( + self.__class__, + ( + { + "method": self.method, + "remote_address": self.remote_address, + "url": self.url, + "body": self.body, + "headers": self.headers, + "query": self.query, + "cookies": self.cookies, + "source": self.source, + "route": self.route, + "subdomains": self.subdomains, + "user": self.user, + "xml": self.xml, + "outgoing_req_redirects": self.outgoing_req_redirects, + "executed_middleware": self.executed_middleware, + "route_params": self.route_params, + }, + None, + None, + ), + ) def set_as_current_context(self): + """ + Set the current context + """ current_context.set(self) - def set_body(self, body: Optional[Any]): + def set_body(self, body): try: self.set_body_internal(body) except Exception as e: @@ -85,5 +127,9 @@ def get_route_metadata(self): "url": self.url, } - def get_user_agent(self) -> Optional[str]: - return self.get_header("USER_AGENT") + def get_user_agent(self): + if "USER_AGENT" not in self.headers: + return None + if isinstance(self.headers["USER_AGENT"], list): + return self.headers["USER_AGENT"][-1] + return self.headers["USER_AGENT"] From 340efba2fa8c9c74ef6426c5b443f561af2e8ce4 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 14:03:34 +0200 Subject: [PATCH 03/21] Context: add typing --- aikido_zen/context/__init__.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index af541ecd..b14cbfad 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -6,6 +6,7 @@ import json from json import JSONDecodeError from time import sleep +from typing import Optional, Dict, List from urllib.parse import parse_qs from aikido_zen.helpers.build_route_from_url import build_route_from_url @@ -16,17 +17,18 @@ from .extract_route_params import extract_route_params UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml", "route_params"] -current_context = contextvars.ContextVar("current_context", default=None) +current_context = contextvars.ContextVar[Optional["Context"]]( + "current_context", default=None +) WSGI_SOURCES = ["django", "flask"] ASGI_SOURCES = ["quart", "django_async", "starlette"] -def get_current_context(): - """Returns the current context""" +def get_current_context() -> Optional["Context"]: try: return current_context.get() - except Exception: + except LookupError: return None @@ -41,30 +43,30 @@ def __init__(self, context_obj=None, body=None, req=None, source=None): logger.debug("Creating Context instance based on dict object.") self.__dict__.update(context_obj) return - # Define emtpy variables/Properties : + self.source = source self.user = None + self.method = None + self.remote_address = None + self.url = None self.parsed_userinput = {} self.xml = {} self.outgoing_req_redirects = [] + self.headers: Dict[str, List[str]] = dict() + self.query: Dict[str, List[str]] = dict() + self.cookies: Dict[str, List[str]] = dict() + self.executed_middleware = False self.set_body(body) - # Parse WSGI/ASGI/... request : - self.cookies = self.method = self.remote_address = self.query = self.headers = ( - self.url - ) = None if source in WSGI_SOURCES: set_wsgi_attributes_on_context(self, req) elif source in ASGI_SOURCES: set_asgi_attributes_on_context(self, req) - # Define variables using parsed request : self.route = build_route_from_url(self.url) self.route_params = extract_route_params(self.url) self.subdomains = get_subdomains_from_url(self.url) - self.executed_middleware = False - def __reduce__(self): return ( self.__class__, From bc2b134d92599c66369b1836060c37323f0c6cc1 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 14:03:55 +0200 Subject: [PATCH 04/21] Create get_header function on context --- aikido_zen/context/__init__.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index b14cbfad..cc6466af 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -129,9 +129,10 @@ def get_route_metadata(self): "url": self.url, } - def get_user_agent(self): - if "USER_AGENT" not in self.headers: + def get_header(self, key: str) -> Optional[str]: + if key not in self.headers or not self.headers[key]: return None - if isinstance(self.headers["USER_AGENT"], list): - return self.headers["USER_AGENT"][-1] - return self.headers["USER_AGENT"] + return self.headers[key][-1] + + def get_user_agent(self) -> Optional[str]: + return self.get_header("USER_AGENT") From a784679bb5f4cf3a220096a888e55890e8466dba Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 14:40:59 +0200 Subject: [PATCH 05/21] asgi: now returns ASGIContext --- aikido_zen/context/asgi/__init__.py | 44 +++++++++++++-------- aikido_zen/context/asgi/init_test.py | 58 +++++++++++++++++----------- 2 files changed, 63 insertions(+), 39 deletions(-) diff --git a/aikido_zen/context/asgi/__init__.py b/aikido_zen/context/asgi/__init__.py index d083ece0..1fc15f77 100644 --- a/aikido_zen/context/asgi/__init__.py +++ b/aikido_zen/context/asgi/__init__.py @@ -1,31 +1,41 @@ -"""Exports set_asgi_attributes_on_context""" - +from dataclasses import dataclass +from typing import Dict, List from urllib.parse import parse_qs from aikido_zen.helpers.get_ip_from_request import get_ip_from_request -from aikido_zen.helpers.logging import logger from ..parse_cookies import parse_cookies from .normalize_asgi_headers import normalize_asgi_headers from .build_url_from_asgi import build_url_from_asgi -def set_asgi_attributes_on_context(context, scope): +@dataclass +class ASGIContext: + method: str + headers: Dict[str, List[str]] + cookies: dict + url: str + query: dict + remote_address: str + + +def parse_asgi_scope(scope) -> ASGIContext: """ This extracts ASGI Scope attributes, described in : https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope """ - logger.debug("Setting ASGI attributes") - context.method = scope["method"] - context.headers = normalize_asgi_headers(scope["headers"]) + headers = normalize_asgi_headers(scope["headers"]) - if "COOKIE" in context.headers and len(context.headers["COOKIE"]) > 0: - # Right now just use the first Cookie header, will change later to use - # framework definition for cookies. - context.cookies = parse_cookies(context.headers["COOKIE"][0]) - else: - context.cookies = {} - - context.url = build_url_from_asgi(scope) - context.query = parse_qs(scope["query_string"].decode("utf-8")) + cookies = {} + if "COOKIE" in headers and headers["COOKIE"]: + cookies = parse_cookies(headers["COOKIE"][-1]) raw_ip = scope["client"][0] if scope["client"] else "" - context.remote_address = get_ip_from_request(raw_ip, context.headers) + remote_address = get_ip_from_request(raw_ip, headers) + + return ASGIContext( + method=scope["method"], + headers=headers, + cookies=cookies, + url=build_url_from_asgi(scope), + query=parse_qs(scope["query_string"].decode("utf-8")), + remote_address=remote_address, + ) diff --git a/aikido_zen/context/asgi/init_test.py b/aikido_zen/context/asgi/init_test.py index ad05d6fe..3e0bb032 100644 --- a/aikido_zen/context/asgi/init_test.py +++ b/aikido_zen/context/asgi/init_test.py @@ -1,16 +1,5 @@ import pytest -from aikido_zen.context.asgi import set_asgi_attributes_on_context - - -class Context: - def __init__(self): - self.headers = None - self.cookies = None - self.method = None - self.url = None - self.query = None - self.remote_address = None - +from aikido_zen.context.asgi import parse_asgi_scope # Scope 1 : TEST_ASGI_SCOPE_1 = { @@ -26,8 +15,7 @@ def __init__(self): def test_asgi_scope_1(): - context1 = Context() - set_asgi_attributes_on_context(context1, TEST_ASGI_SCOPE_1) + context1 = parse_asgi_scope(TEST_ASGI_SCOPE_1) assert context1.method == "PUT" assert context1.remote_address == "1.1.1.1" assert context1.query == {"a": ["b"], "b": ["d"]} @@ -53,8 +41,7 @@ def test_asgi_scope_1(): def test_asgi_scope_2(): - context2 = Context() - set_asgi_attributes_on_context(context2, TEST_ASGI_SCOPE_2) + context2 = parse_asgi_scope(TEST_ASGI_SCOPE_2) assert context2.method == "GET" assert context2.remote_address == "2.2.2.2" assert context2.query == {"x": ["y"], "z": ["w"]} @@ -80,8 +67,7 @@ def test_asgi_scope_2(): def test_asgi_scope_3(): - context3 = Context() - set_asgi_attributes_on_context(context3, TEST_ASGI_SCOPE_3) + context3 = parse_asgi_scope(TEST_ASGI_SCOPE_3) assert context3.method == "POST" assert context3.remote_address == "3.3.3.3" assert context3.query == {"key1": ["value1"], "key2": ["value2"]} @@ -107,8 +93,7 @@ def test_asgi_scope_3(): def test_asgi_scope_4(): - context4 = Context() - set_asgi_attributes_on_context(context4, TEST_ASGI_SCOPE_4) + context4 = parse_asgi_scope(TEST_ASGI_SCOPE_4) assert context4.method == "DELETE" assert context4.remote_address == "4.4.4.4" assert context4.query == {} @@ -137,8 +122,7 @@ def test_asgi_scope_4(): def test_asgi_scope_multiple_header_values(): - context4 = Context() - set_asgi_attributes_on_context(context4, TEST_ASGI_SCOPE_MULTIPLE_HEADER_VALUES) + context4 = parse_asgi_scope(TEST_ASGI_SCOPE_MULTIPLE_HEADER_VALUES) assert context4.method == "DELETE" assert context4.remote_address == "4.4.4.4" assert context4.query == {} @@ -147,3 +131,33 @@ def test_asgi_scope_multiple_header_values(): } assert context4.cookies == {} # No cookies in this scope assert context4.url == "https://192.168.0.4:443/resource/123" + + +# Scope 5 : +TEST_ASGI_SCOPE_5 = { + "method": "POST", + "headers": [ + (b"COOKIE", b"session=abc"), + (b"COOKIE", b"session=abc123"), + (b"header3_test-3", b"postValue"), + ], + "query_string": b"key1=value1&key2=value2", + "client": ["3.3.3.3"], + "server": ["192.168.0.3", 8080], + "scheme": "http", + "root_path": "/api", + "path": "/api/v1/resource", +} + + +def test_asgi_scope_5(): + context5 = parse_asgi_scope(TEST_ASGI_SCOPE_5) + assert context5.method == "POST" + assert context5.remote_address == "3.3.3.3" + assert context5.query == {"key1": ["value1"], "key2": ["value2"]} + assert context5.headers == { + "COOKIE": ["session=abc", "session=abc123"], + "HEADER3_TEST_3": ["postValue"], + } + assert context5.cookies == {"session": "abc123"} + assert context5.url == "http://192.168.0.3:8080/v1/resource" From f9fe2ec6db98bdf5a1a53351141c484ad0e445ed Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 14:41:17 +0200 Subject: [PATCH 06/21] WSGI: now returns WSGIContext --- aikido_zen/context/wsgi/__init__.py | 47 +++++++++++-------- .../context/wsgi/extract_wsgi_headers.py | 6 ++- .../context/wsgi/extract_wsgi_headers_test.py | 10 ++-- 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/aikido_zen/context/wsgi/__init__.py b/aikido_zen/context/wsgi/__init__.py index 6c0f5c89..d2034be9 100644 --- a/aikido_zen/context/wsgi/__init__.py +++ b/aikido_zen/context/wsgi/__init__.py @@ -1,32 +1,41 @@ -"""Exports set_wsgi_attributes_on_context""" - +from dataclasses import dataclass +from typing import Dict, List from urllib.parse import parse_qs from aikido_zen.helpers.get_ip_from_request import get_ip_from_request -from aikido_zen.helpers.logging import logger from .extract_wsgi_headers import extract_wsgi_headers from .build_url_from_wsgi import build_url_from_wsgi from ..parse_cookies import parse_cookies -def set_wsgi_attributes_on_context(context, environ): +@dataclass +class WSGIContext: + method: str + headers: Dict[str, List[str]] + cookies: dict + url: str + query: dict + remote_address: str + + +def parse_wsgi_environ(environ) -> WSGIContext: """ This extracts WSGI attributes, described in : https://peps.python.org/pep-3333/#environ-variables """ - logger.debug("Setting wsgi attributes") - - context.method = environ["REQUEST_METHOD"] - context.headers = extract_wsgi_headers(environ) - if "COOKIE" in context.headers: - context.cookies = parse_cookies(context.headers["COOKIE"]) - else: - context.cookies = {} - context.url = build_url_from_wsgi(environ) - context.query = parse_qs(environ["QUERY_STRING"]) - context.remote_address = get_ip_from_request( - environ["REMOTE_ADDR"], context.headers - ) - + headers: Dict[str, List[str]] = extract_wsgi_headers(environ) # Content type is generally not included as a header, do include this as a header to simplify : if "CONTENT_TYPE" in environ: - context.headers["CONTENT_TYPE"] = environ["CONTENT_TYPE"] + headers["CONTENT_TYPE"] = [environ["CONTENT_TYPE"]] + + cookies = {} + if "COOKIE" in headers and headers["COOKIE"]: + cookies = parse_cookies(headers["COOKIE"][-1]) + + return WSGIContext( + method=environ["REQUEST_METHOD"], + headers=headers, + cookies=cookies, + url=build_url_from_wsgi(environ), + query=parse_qs(environ["QUERY_STRING"]), + remote_address=get_ip_from_request(environ["REMOTE_ADDR"], headers), + ) diff --git a/aikido_zen/context/wsgi/extract_wsgi_headers.py b/aikido_zen/context/wsgi/extract_wsgi_headers.py index c20cc3ae..2bcd28a3 100644 --- a/aikido_zen/context/wsgi/extract_wsgi_headers.py +++ b/aikido_zen/context/wsgi/extract_wsgi_headers.py @@ -1,11 +1,13 @@ """Exports function extract_wsgi_headers""" +from typing import Dict, List -def extract_wsgi_headers(request): + +def extract_wsgi_headers(request) -> Dict[str, List[str]]: """Extracts WSGI headers which start with HTTP_ from request dict""" headers = {} for key, value in request.items(): if key.startswith("HTTP_"): # Remove the 'HTTP_' prefix and store in the headers dictionary - headers[key[5:]] = value + headers[key[5:]] = [value] return headers diff --git a/aikido_zen/context/wsgi/extract_wsgi_headers_test.py b/aikido_zen/context/wsgi/extract_wsgi_headers_test.py index 85fc478d..6336f3a7 100644 --- a/aikido_zen/context/wsgi/extract_wsgi_headers_test.py +++ b/aikido_zen/context/wsgi/extract_wsgi_headers_test.py @@ -4,7 +4,7 @@ def test_extract_wsgi_headers_single_header(): request = {"REQUEST_METHOD": "GET", "HTTP_USER_AGENT": "Mozilla/5.0"} - expected = {"USER_AGENT": "Mozilla/5.0"} + expected = {"USER_AGENT": ["Mozilla/5.0"]} assert extract_wsgi_headers(request) == expected @@ -16,9 +16,9 @@ def test_extract_wsgi_headers_multiple_headers(): "HTTP_CONTENT_TYPE": "application/json", } expected = { - "HOST": "example.com", - "ACCEPT": "text/html", - "CONTENT_TYPE": "application/json", + "HOST": ["example.com"], + "ACCEPT": ["text/html"], + "CONTENT_TYPE": ["application/json"], } assert extract_wsgi_headers(request) == expected @@ -42,5 +42,5 @@ def test_extract_wsgi_headers_mixed_headers(): "HTTP_ACCEPT_LANGUAGE": "en-US,en;q=0.5", "OTHER_HEADER": "value", } - expected = {"USER_AGENT": "Mozilla/5.0", "ACCEPT_LANGUAGE": "en-US,en;q=0.5"} + expected = {"USER_AGENT": ["Mozilla/5.0"], "ACCEPT_LANGUAGE": ["en-US,en;q=0.5"]} assert extract_wsgi_headers(request) == expected From e8d972e07ccd4716666b562d2db0a5459ccc17f3 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 14:41:33 +0200 Subject: [PATCH 07/21] Add typing to parse_cookies A --- aikido_zen/context/parse_cookies.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aikido_zen/context/parse_cookies.py b/aikido_zen/context/parse_cookies.py index ddc729ca..23771420 100644 --- a/aikido_zen/context/parse_cookies.py +++ b/aikido_zen/context/parse_cookies.py @@ -3,9 +3,10 @@ """ from http.cookies import SimpleCookie, CookieError +from typing import Dict -def parse_cookies(cookie_str): +def parse_cookies(cookie_str: str) -> Dict[str, str]: """Parse cookie string from headers""" cookie_dict = {} cookies = SimpleCookie() From 516364e4d300a3c93ec419ae5168d115f28dd791 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 14:41:43 +0200 Subject: [PATCH 08/21] Refactors context initialization for WSGI/ASGI Changes context initialization to utilize dedicated parsing functions (parse_wsgi_environ, parse_asgi_scope) for WSGI and ASGI requests. This improves code organization and maintainability by decoupling context attribute assignment from the main Context class, and promotes code reuse. Updates tests to reflect the change in header structure. --- aikido_zen/context/__init__.py | 25 +++++++++++++++++------- aikido_zen/context/init_test.py | 34 ++++++++++++++++----------------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index cc6466af..ef31a707 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -4,16 +4,13 @@ import contextvars import json -from json import JSONDecodeError -from time import sleep from typing import Optional, Dict, List -from urllib.parse import parse_qs from aikido_zen.helpers.build_route_from_url import build_route_from_url from aikido_zen.helpers.get_subdomains_from_url import get_subdomains_from_url from aikido_zen.helpers.logging import logger -from .wsgi import set_wsgi_attributes_on_context -from .asgi import set_asgi_attributes_on_context +from .wsgi import parse_wsgi_environ, WSGIContext +from .asgi import parse_asgi_scope, ASGIContext from .extract_route_params import extract_route_params UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml", "route_params"] @@ -59,9 +56,23 @@ def __init__(self, context_obj=None, body=None, req=None, source=None): self.set_body(body) if source in WSGI_SOURCES: - set_wsgi_attributes_on_context(self, req) + wsgi_context: WSGIContext = parse_wsgi_environ(req) + self.method = wsgi_context.method + self.remote_address = wsgi_context.remote_address + self.url = wsgi_context.url + self.headers = wsgi_context.headers + self.query = wsgi_context.query + self.cookies = wsgi_context.cookies elif source in ASGI_SOURCES: - set_asgi_attributes_on_context(self, req) + asgi_context: ASGIContext = parse_asgi_scope(req) + self.method = asgi_context.method + self.remote_address = asgi_context.remote_address + self.url = asgi_context.url + self.headers = asgi_context.headers + self.query = asgi_context.query + self.cookies = asgi_context.cookies + else: + raise Exception("Unsupported source: " + source) self.route = build_route_from_url(self.url) self.route_params = extract_route_params(self.url) diff --git a/aikido_zen/context/init_test.py b/aikido_zen/context/init_test.py index 4769167b..b9118337 100644 --- a/aikido_zen/context/init_test.py +++ b/aikido_zen/context/init_test.py @@ -52,11 +52,11 @@ def test_wsgi_context_1(): "source": "django", "method": "POST", "headers": { - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "COOKIE": "sessionId=abc123xyz456;", - "HOST": "example.com", - "CONTENT_TYPE": "application/x-www-form-urlencoded", + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "COOKIE": ["sessionId=abc123xyz456;"], + "HOST": ["example.com"], + "CONTENT_TYPE": ["application/x-www-form-urlencoded"], }, "cookies": {"sessionId": "abc123xyz456"}, "url": "https://example.com/hello", @@ -81,12 +81,12 @@ def test_wsgi_context_2(): "source": "flask", "method": "GET", "headers": { - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "COOKIE": "sessionId=abc123xyz456;", - "HOST": "localhost:8080", - "CONTENT_TYPE": "application/json", - "USER_AGENT": "Mozilla/5.0", + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "COOKIE": ["sessionId=abc123xyz456;"], + "HOST": ["localhost:8080"], + "CONTENT_TYPE": ["application/json"], + "USER_AGENT": ["Mozilla/5.0"], }, "cookies": {"sessionId": "abc123xyz456"}, "url": "http://localhost:8080/hello", @@ -129,12 +129,12 @@ def test_context_is_picklable(mocker): assert unpickled_obj.url == "http://localhost:8080/hello" assert unpickled_obj.body == 123 assert unpickled_obj.headers == { - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "COOKIE": "sessionId=abc123xyz456;", - "HOST": "localhost:8080", - "CONTENT_TYPE": "application/json", - "USER_AGENT": "Mozilla/5.0", + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "COOKIE": ["sessionId=abc123xyz456;"], + "HOST": ["localhost:8080"], + "CONTENT_TYPE": ["application/json"], + "USER_AGENT": ["Mozilla/5.0"], } assert unpickled_obj.query == {"user": ["JohnDoe"], "age": ["30", "35"]} assert unpickled_obj.cookies == {"sessionId": "abc123xyz456"} From 1834a2bfb8e86218f5b85eed0772a7878172ec2c Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 14:52:06 +0200 Subject: [PATCH 09/21] Fix flask test cases --- aikido_zen/sources/flask/flask_test.py | 30 +++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/aikido_zen/sources/flask/flask_test.py b/aikido_zen/sources/flask/flask_test.py index 4a1d448a..9d25f765 100644 --- a/aikido_zen/sources/flask/flask_test.py +++ b/aikido_zen/sources/flask/flask_test.py @@ -104,11 +104,11 @@ def hello(user, age): assert get_current_context().method == "POST" assert get_current_context().body is None assert get_current_context().headers == { - "COOKIE": "sessionId=abc123xyz456;", - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "HOST": "example.com", - "CONTENT_TYPE": "application/json", + "COOKIE": ["sessionId=abc123xyz456;"], + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "HOST": ["example.com"], + "CONTENT_TYPE": ["application/json"], } calls = mock_request_handler.call_args_list assert len(calls) == 2 @@ -175,11 +175,11 @@ def test_flask_all_3_func_with_invalid_body(): get_current_context().body == None ) # body is None since it's invalid json assert get_current_context().headers == { - "COOKIE": "sessionId=abc123xyz456;", - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "HOST": "example.com", - "CONTENT_TYPE": "application/json", + "COOKIE": ["sessionId=abc123xyz456;"], + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "HOST": ["example.com"], + "CONTENT_TYPE": ["application/json"], } calls = mock_request_handler.call_args_list assert len(calls) == 2 @@ -208,11 +208,11 @@ def test_flask_all_3_func(): assert get_current_context().method == "POST" assert get_current_context().body == None assert get_current_context().headers == { - "COOKIE": "sessionId=abc123xyz456;", - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "HOST": "example.com", - "CONTENT_TYPE": "application/x-www-form-urlencoded", + "COOKIE": ["sessionId=abc123xyz456;"], + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "HOST": ["example.com"], + "CONTENT_TYPE": ["application/x-www-form-urlencoded"], } calls = mock_request_handler.call_args_list assert len(calls) == 2 From 94ff879b1bbdb4a8430abba8d271d8c390feaa2c Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:03:05 +0200 Subject: [PATCH 10/21] get_ip_from_request: make capable of using headers --- aikido_zen/helpers/get_ip_form_request_test.py | 8 ++++---- aikido_zen/helpers/get_ip_from_request.py | 13 +++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/aikido_zen/helpers/get_ip_form_request_test.py b/aikido_zen/helpers/get_ip_form_request_test.py index 6935ea89..e357552b 100644 --- a/aikido_zen/helpers/get_ip_form_request_test.py +++ b/aikido_zen/helpers/get_ip_form_request_test.py @@ -9,11 +9,11 @@ # Test `get_ip_from_request` function : def test_get_ip_from_request(): # Test case 1: Valid X_FORWARDED_FOR header with valid IP - headers = {"X_FORWARDED_FOR": "192.168.1.1, 10.0.0.1"} + headers = {"X_FORWARDED_FOR": ["192.168.1.1, 10.0.0.1"]} assert get_ip_from_request(None, headers) == "192.168.1.1" # Test case 2: Valid X_FORWARDED_FOR header with invalid IPs - headers = {"X_FORWARDED_FOR": "256.256.256.256, 192.168.1.1"} + headers = {"X_FORWARDED_FOR": ["256.256.256.256, 192.168.1.1"]} assert ( get_ip_from_request(None, headers) == "192.168.1.1" ) # Should return the valid IP @@ -23,13 +23,13 @@ def test_get_ip_from_request(): assert get_ip_from_request("10.0.0.1", headers) == "10.0.0.1" # Test case 4: Valid remote address with invalid X_FORWARDED_FOR - headers = {"X_FORWARDED_FOR": "abc.def.ghi.jkl, 256.256.256.256"} + headers = {"X_FORWARDED_FOR": ["abc.def.ghi.jkl, 256.256.256.256"]} assert ( get_ip_from_request("10.0.0.1", headers) == "10.0.0.1" ) # Should return the remote address # Test case 5: Both X_FORWARDED_FOR and remote address are invalid - headers = {"X_FORWARDED_FOR": "abc.def.ghi.jkl, 256.256.256.256"} + headers = {"X_FORWARDED_FOR": ["abc.def.ghi.jkl, 256.256.256.256"]} assert get_ip_from_request(None, headers) is None # Should return None # Test case 6: Empty headers and remote address diff --git a/aikido_zen/helpers/get_ip_from_request.py b/aikido_zen/helpers/get_ip_from_request.py index e67a0a30..567ef1e6 100644 --- a/aikido_zen/helpers/get_ip_from_request.py +++ b/aikido_zen/helpers/get_ip_from_request.py @@ -4,18 +4,19 @@ import socket import os +from typing import Dict, List, Optional + from aikido_zen.helpers.logging import logger -def get_ip_from_request(remote_address, headers): +def get_ip_from_request(remote_address, headers: Dict[str, List[str]]) -> Optional[str]: """ Tries and get the IP address from the request, checking for x-forwarded-for """ - if headers: - lower_headers = {key.lower(): value for key, value in headers.items()} - if "x_forwarded_for" in lower_headers and trust_proxy(): + if headers and "X_FORWARDED_FOR" in headers and headers["X_FORWARDED_FOR"]: + if trust_proxy(): x_forwarded_for = get_client_ip_from_x_forwarded_for( - lower_headers["x_forwarded_for"] + headers["X_FORWARDED_FOR"][-1] ) if x_forwarded_for and is_ip(x_forwarded_for): @@ -27,7 +28,7 @@ def get_ip_from_request(remote_address, headers): return None -def get_client_ip_from_x_forwarded_for(value): +def get_client_ip_from_x_forwarded_for(value: str) -> Optional[str]: """ Fetches the IP out of the X-Forwarder-For headers """ From cd9f79a83a521c0c8ceefe7281bfa5244ded91c9 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:18:57 +0200 Subject: [PATCH 11/21] Create new headers store helper class --- aikido_zen/helpers/headers.py | 35 +++++++++++++++ aikido_zen/helpers/headers_test.py | 72 ++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 aikido_zen/helpers/headers.py create mode 100644 aikido_zen/helpers/headers_test.py diff --git a/aikido_zen/helpers/headers.py b/aikido_zen/helpers/headers.py new file mode 100644 index 00000000..c384d62e --- /dev/null +++ b/aikido_zen/helpers/headers.py @@ -0,0 +1,35 @@ +from typing import List + +from typing_extensions import Optional + + +class Headers(dict): + def store_headers(self, key: str, values: List[str]): + normalized_key = self.normalize_header_key(key) + if self.get(normalized_key, []): + self[normalized_key] += values + else: + self[normalized_key] = values + + def store_header(self, key: str, value: str): + self.store_headers(key, [value]) + + def get_header(self, key: str) -> Optional[str]: + self.validate_header_key(key) + if self.get(key, []): + return self.get(key)[-1] + return None + + @staticmethod + def validate_header_key(key: str): + if not key.isupper(): + raise ValueError("Header key must be uppercase.") + if "-" in key: + raise ValueError("Header key must use underscores instead of dashes.") + + @staticmethod + def normalize_header_key(key: str) -> str: + result = str(key) + result = result.replace("-", "_") + result = result.upper() + return result diff --git a/aikido_zen/helpers/headers_test.py b/aikido_zen/helpers/headers_test.py new file mode 100644 index 00000000..69b2dcd9 --- /dev/null +++ b/aikido_zen/helpers/headers_test.py @@ -0,0 +1,72 @@ +import pytest +from .headers import Headers + + +def test_store_header_single_value(): + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/json") + assert headers["CONTENT_TYPE"] == ["application/json"] + + +def test_store_header_multiple_values(): + headers = Headers() + headers.store_headers("CONTENT_TYPE", ["application/json"]) + headers.store_headers("CONTENT_TYPE", ["text/html"]) + assert headers["CONTENT_TYPE"] == ["application/json", "text/html"] + + +def test_get_header_existing_key(): + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/json") + assert headers.get_header("CONTENT_TYPE") == "application/json" + + +def test_get_header_non_existing_key(): + headers = Headers() + assert headers.get_header("NON_EXISTING_KEY") is None + + +def test_get_header_multiple_values(): + headers = Headers() + headers.store_headers("CONTENT_TYPE", ["application/json", "text/html"]) + assert headers.get_header("CONTENT_TYPE") == "text/html" + + +def test_validate_header_key_valid(): + try: + Headers.validate_header_key("VALID_HEADER") + except ValueError: + pytest.fail("validate_header_key raised ValueError unexpectedly!") + + +def test_validate_header_key_not_uppercase(): + with pytest.raises(ValueError, match="Header key must be uppercase."): + Headers.validate_header_key("invalid_header") + + +def test_validate_header_key_with_dash(): + with pytest.raises( + ValueError, match="Header key must use underscores instead of dashes." + ): + Headers.validate_header_key("INVALID-HEADER") + + +def test_normalize_header_key_valid(): + assert Headers.normalize_header_key("valid-header") == "VALID_HEADER" + assert Headers.normalize_header_key("ANOTHER-HEADER") == "ANOTHER_HEADER" + + +def test_normalize_header_key_already_normalized(): + assert Headers.normalize_header_key("ALREADY_NORMALIZED") == "ALREADY_NORMALIZED" + + +def test_store_headers_with_empty_list(): + headers = Headers() + headers.store_headers("CONTENT_TYPE", []) + assert headers.get_header("CONTENT_TYPE") is None + + +def test_store_header_with_empty_string(): + headers = Headers() + headers.store_header("CONTENT_TYPE", "") + assert headers.get_header("CONTENT_TYPE") == "" From 0331e85c979702cda7db113fac3553dbcc5f2104 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:19:12 +0200 Subject: [PATCH 12/21] ASGI use new headers helper class --- aikido_zen/context/asgi/__init__.py | 4 +- .../context/asgi/extract_asgi_headers.py | 10 +++++ ...s_test.py => extract_asgi_headers_test.py} | 38 +++++++++---------- .../context/asgi/normalize_asgi_headers.py | 17 --------- 4 files changed, 31 insertions(+), 38 deletions(-) create mode 100644 aikido_zen/context/asgi/extract_asgi_headers.py rename aikido_zen/context/asgi/{normalize_asgi_headers_test.py => extract_asgi_headers_test.py} (66%) delete mode 100644 aikido_zen/context/asgi/normalize_asgi_headers.py diff --git a/aikido_zen/context/asgi/__init__.py b/aikido_zen/context/asgi/__init__.py index 1fc15f77..bd51f347 100644 --- a/aikido_zen/context/asgi/__init__.py +++ b/aikido_zen/context/asgi/__init__.py @@ -3,7 +3,7 @@ from urllib.parse import parse_qs from aikido_zen.helpers.get_ip_from_request import get_ip_from_request from ..parse_cookies import parse_cookies -from .normalize_asgi_headers import normalize_asgi_headers +from .extract_asgi_headers import extract_asgi_headers from .build_url_from_asgi import build_url_from_asgi @@ -22,7 +22,7 @@ def parse_asgi_scope(scope) -> ASGIContext: This extracts ASGI Scope attributes, described in : https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope """ - headers = normalize_asgi_headers(scope["headers"]) + headers = extract_asgi_headers(scope["headers"]) cookies = {} if "COOKIE" in headers and headers["COOKIE"]: diff --git a/aikido_zen/context/asgi/extract_asgi_headers.py b/aikido_zen/context/asgi/extract_asgi_headers.py new file mode 100644 index 00000000..cadce5fe --- /dev/null +++ b/aikido_zen/context/asgi/extract_asgi_headers.py @@ -0,0 +1,10 @@ +"""Mainly exports extract_asgi_headers""" + +from aikido_zen.helpers.headers import Headers + + +def extract_asgi_headers(headers) -> Headers: + result = Headers() + for k, v in headers: + result.store_header(k.decode("utf-8"), v.decode("utf-8")) + return result diff --git a/aikido_zen/context/asgi/normalize_asgi_headers_test.py b/aikido_zen/context/asgi/extract_asgi_headers_test.py similarity index 66% rename from aikido_zen/context/asgi/normalize_asgi_headers_test.py rename to aikido_zen/context/asgi/extract_asgi_headers_test.py index 0c8f414a..240c3728 100644 --- a/aikido_zen/context/asgi/normalize_asgi_headers_test.py +++ b/aikido_zen/context/asgi/extract_asgi_headers_test.py @@ -1,23 +1,23 @@ import pytest -from .normalize_asgi_headers import normalize_asgi_headers +from .extract_asgi_headers import extract_asgi_headers -def test_normalize_asgi_headers_basic(): +def test_extract_asgi_headers_basic(): headers = [ (b"content-type", b"text/html"), (b"accept-encoding", b"gzip, deflate"), ] expected = {"CONTENT_TYPE": ["text/html"], "ACCEPT_ENCODING": ["gzip, deflate"]} - assert normalize_asgi_headers(headers) == expected + assert extract_asgi_headers(headers) == expected -def test_normalize_asgi_headers_empty(): +def test_extract_asgi_headers_empty(): headers = [] expected = {} - assert normalize_asgi_headers(headers) == expected + assert extract_asgi_headers(headers) == expected -def test_normalize_asgi_headers_with_special_characters(): +def test_extract_asgi_headers_with_special_characters(): headers = [ (b"content-type", b"text/html"), (b"x-custom-header", b"some_value"), @@ -28,19 +28,19 @@ def test_normalize_asgi_headers_with_special_characters(): "X_CUSTOM_HEADER": ["some_value"], "ACCEPT_ENCODING": ["gzip, deflate"], } - assert normalize_asgi_headers(headers) == expected + assert extract_asgi_headers(headers) == expected -def test_normalize_asgi_headers_with_dashes(): +def test_extract_asgi_headers_with_dashes(): headers = [ (b"X-Forwarded-For", b"192.168.1.1"), (b"X-Request-ID", b"abc123"), ] expected = {"X_FORWARDED_FOR": ["192.168.1.1"], "X_REQUEST_ID": ["abc123"]} - assert normalize_asgi_headers(headers) == expected + assert extract_asgi_headers(headers) == expected -def test_normalize_asgi_headers_case_insensitivity(): +def test_extract_asgi_headers_case_insensitivity(): headers = [ (b"Content-Type", b"text/html"), (b"ACCEPT-ENCODING", b"gzip, deflate"), @@ -50,19 +50,19 @@ def test_normalize_asgi_headers_case_insensitivity(): "CONTENT_TYPE": ["text/html"], "ACCEPT_ENCODING": ["gzip, deflate", "json"], } - assert normalize_asgi_headers(headers) == expected + assert extract_asgi_headers(headers) == expected -def test_normalize_asgi_headers_unicode(): +def test_extract_asgi_headers_unicode(): headers = [ (b"content-type", b"text/html"), (b"custom-header", b"value"), ] expected = {"CONTENT_TYPE": ["text/html"], "CUSTOM_HEADER": ["value"]} - assert normalize_asgi_headers(headers) == expected + assert extract_asgi_headers(headers) == expected -def test_normalize_asgi_headers_mixed_case_and_dashes(): +def test_extract_asgi_headers_mixed_case_and_dashes(): headers = [ (b"Content-Type", b"text/html"), (b"X-Custom-Header", b"some_value"), @@ -73,10 +73,10 @@ def test_normalize_asgi_headers_mixed_case_and_dashes(): "X_CUSTOM_HEADER": ["some_value"], "ACCEPT_ENCODING": ["gzip, deflate"], } - assert normalize_asgi_headers(headers) == expected + assert extract_asgi_headers(headers) == expected -def test_normalize_asgi_headers_non_ascii(): +def test_extract_asgi_headers_non_ascii(): headers = [ (b"content-type", b"text/html"), (b"custom-header", b"value"), @@ -87,13 +87,13 @@ def test_normalize_asgi_headers_non_ascii(): "CUSTOM_HEADER": ["value"], "X_HEADER_WITH_EMOJI": ["test"], } - assert normalize_asgi_headers(headers) == expected + assert extract_asgi_headers(headers) == expected -def test_normalize_asgi_headers_large_input(): +def test_extract_asgi_headers_large_input(): headers = [ (f"header-{i}".encode("utf-8"), f"value-{i}".encode("utf-8")) for i in range(1000) ] expected = {f"HEADER_{i}": [f"value-{i}"] for i in range(1000)} - assert normalize_asgi_headers(headers) == expected + assert extract_asgi_headers(headers) == expected diff --git a/aikido_zen/context/asgi/normalize_asgi_headers.py b/aikido_zen/context/asgi/normalize_asgi_headers.py deleted file mode 100644 index c62e7ac2..00000000 --- a/aikido_zen/context/asgi/normalize_asgi_headers.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Mainly exports normalize_asgi_headers""" - - -def normalize_asgi_headers(headers): - """ - Normalizes headers provided by ASGI : - Decodes them, uppercase and underscore keys - """ - parsed_headers = {} - for k, v in headers: - # Normalizing key : decoding, removing dashes and uppercase - key_without_dashes = k.decode("utf-8").replace("-", "_") - key_normalized = key_without_dashes.upper() - if not key_normalized in parsed_headers: - parsed_headers[key_normalized] = list() - parsed_headers[key_normalized].append(v.decode("utf-8")) - return parsed_headers From 8b6e0c1cd0d6fc058e53566812dd2d3e2d6a2c06 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:19:26 +0200 Subject: [PATCH 13/21] WSGI use new headers helper class --- aikido_zen/context/wsgi/__init__.py | 5 +++-- aikido_zen/context/wsgi/extract_wsgi_headers.py | 9 ++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/aikido_zen/context/wsgi/__init__.py b/aikido_zen/context/wsgi/__init__.py index d2034be9..4a927ff7 100644 --- a/aikido_zen/context/wsgi/__init__.py +++ b/aikido_zen/context/wsgi/__init__.py @@ -5,12 +5,13 @@ from .extract_wsgi_headers import extract_wsgi_headers from .build_url_from_wsgi import build_url_from_wsgi from ..parse_cookies import parse_cookies +from ...helpers.headers import Headers @dataclass class WSGIContext: method: str - headers: Dict[str, List[str]] + headers: Headers cookies: dict url: str query: dict @@ -22,7 +23,7 @@ def parse_wsgi_environ(environ) -> WSGIContext: This extracts WSGI attributes, described in : https://peps.python.org/pep-3333/#environ-variables """ - headers: Dict[str, List[str]] = extract_wsgi_headers(environ) + headers: Headers = extract_wsgi_headers(environ) # Content type is generally not included as a header, do include this as a header to simplify : if "CONTENT_TYPE" in environ: headers["CONTENT_TYPE"] = [environ["CONTENT_TYPE"]] diff --git a/aikido_zen/context/wsgi/extract_wsgi_headers.py b/aikido_zen/context/wsgi/extract_wsgi_headers.py index 2bcd28a3..5ab7d22c 100644 --- a/aikido_zen/context/wsgi/extract_wsgi_headers.py +++ b/aikido_zen/context/wsgi/extract_wsgi_headers.py @@ -2,12 +2,15 @@ from typing import Dict, List +from aikido_zen.helpers.headers import Headers -def extract_wsgi_headers(request) -> Dict[str, List[str]]: + +def extract_wsgi_headers(request) -> Headers: """Extracts WSGI headers which start with HTTP_ from request dict""" - headers = {} + headers = Headers() for key, value in request.items(): if key.startswith("HTTP_"): # Remove the 'HTTP_' prefix and store in the headers dictionary - headers[key[5:]] = [value] + header_key = key[5:] + headers.store_header(header_key, value) return headers From 6f30d95502c2937d498f28959f5895b50e317fa6 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:19:38 +0200 Subject: [PATCH 14/21] context: use new headers helper class --- aikido_zen/context/__init__.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index ef31a707..56c72142 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -12,6 +12,7 @@ from .wsgi import parse_wsgi_environ, WSGIContext from .asgi import parse_asgi_scope, ASGIContext from .extract_route_params import extract_route_params +from ..helpers.headers import Headers UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml", "route_params"] current_context = contextvars.ContextVar[Optional["Context"]]( @@ -49,7 +50,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None): self.parsed_userinput = {} self.xml = {} self.outgoing_req_redirects = [] - self.headers: Dict[str, List[str]] = dict() + self.headers: Headers = Headers() self.query: Dict[str, List[str]] = dict() self.cookies: Dict[str, List[str]] = dict() self.executed_middleware = False @@ -140,10 +141,5 @@ def get_route_metadata(self): "url": self.url, } - def get_header(self, key: str) -> Optional[str]: - if key not in self.headers or not self.headers[key]: - return None - return self.headers[key][-1] - def get_user_agent(self) -> Optional[str]: - return self.get_header("USER_AGENT") + return self.headers.get_header("USER_AGENT") From 3a62e3661551f18e5fd89e6c63c1230313852f80 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:21:06 +0200 Subject: [PATCH 15/21] get_ip_from_request use new Headers class --- aikido_zen/helpers/get_ip_from_request.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/aikido_zen/helpers/get_ip_from_request.py b/aikido_zen/helpers/get_ip_from_request.py index 567ef1e6..88b211c4 100644 --- a/aikido_zen/helpers/get_ip_from_request.py +++ b/aikido_zen/helpers/get_ip_from_request.py @@ -6,21 +6,21 @@ import os from typing import Dict, List, Optional +from aikido_zen.helpers.headers import Headers from aikido_zen.helpers.logging import logger -def get_ip_from_request(remote_address, headers: Dict[str, List[str]]) -> Optional[str]: +def get_ip_from_request(remote_address: str, headers: Headers) -> Optional[str]: """ Tries and get the IP address from the request, checking for x-forwarded-for """ - if headers and "X_FORWARDED_FOR" in headers and headers["X_FORWARDED_FOR"]: - if trust_proxy(): - x_forwarded_for = get_client_ip_from_x_forwarded_for( - headers["X_FORWARDED_FOR"][-1] - ) - - if x_forwarded_for and is_ip(x_forwarded_for): - return x_forwarded_for + if headers.get_header("X_FORWARDED_FOR") and trust_proxy(): + x_forwarded_for = get_client_ip_from_x_forwarded_for( + headers.get_header("X_FORWARDED_FOR") + ) + + if x_forwarded_for and is_ip(x_forwarded_for): + return x_forwarded_for if remote_address and is_ip(remote_address): return remote_address From 99129dbad098d24678f2d73e8cc546c1fcbcecfe Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:21:14 +0200 Subject: [PATCH 16/21] asgi also return headers type --- aikido_zen/context/asgi/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aikido_zen/context/asgi/__init__.py b/aikido_zen/context/asgi/__init__.py index bd51f347..21802b0d 100644 --- a/aikido_zen/context/asgi/__init__.py +++ b/aikido_zen/context/asgi/__init__.py @@ -5,12 +5,13 @@ from ..parse_cookies import parse_cookies from .extract_asgi_headers import extract_asgi_headers from .build_url_from_asgi import build_url_from_asgi +from ...helpers.headers import Headers @dataclass class ASGIContext: method: str - headers: Dict[str, List[str]] + headers: Headers cookies: dict url: str query: dict From 01e6c559d897652ddf13f54919fb6f478f5d2ae0 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:30:44 +0200 Subject: [PATCH 17/21] Fix test cases for get_ip_form_request_test.py --- .../helpers/get_ip_form_request_test.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/aikido_zen/helpers/get_ip_form_request_test.py b/aikido_zen/helpers/get_ip_form_request_test.py index e357552b..c59732da 100644 --- a/aikido_zen/helpers/get_ip_form_request_test.py +++ b/aikido_zen/helpers/get_ip_form_request_test.py @@ -4,36 +4,37 @@ is_ip, get_client_ip_from_x_forwarded_for, ) +from .headers import Headers # Test `get_ip_from_request` function : def test_get_ip_from_request(): # Test case 1: Valid X_FORWARDED_FOR header with valid IP - headers = {"X_FORWARDED_FOR": ["192.168.1.1, 10.0.0.1"]} + headers = Headers() + headers.store_header("x-forwarded-for", "192.168.1.1, 10.0.0.1") assert get_ip_from_request(None, headers) == "192.168.1.1" # Test case 2: Valid X_FORWARDED_FOR header with invalid IPs - headers = {"X_FORWARDED_FOR": ["256.256.256.256, 192.168.1.1"]} + headers = Headers() + headers.store_header("x-forwarded-for", "256.256.256.256, 192.168.1.1") assert ( get_ip_from_request(None, headers) == "192.168.1.1" ) # Should return the valid IP # Test case 3: Valid remote address - headers = {} + headers = Headers() assert get_ip_from_request("10.0.0.1", headers) == "10.0.0.1" # Test case 4: Valid remote address with invalid X_FORWARDED_FOR - headers = {"X_FORWARDED_FOR": ["abc.def.ghi.jkl, 256.256.256.256"]} + headers = Headers() + headers.store_header("x-forwarded-for", "abc.def.ghi.jkl, 256.256.256.256") assert ( get_ip_from_request("10.0.0.1", headers) == "10.0.0.1" ) # Should return the remote address # Test case 5: Both X_FORWARDED_FOR and remote address are invalid - headers = {"X_FORWARDED_FOR": ["abc.def.ghi.jkl, 256.256.256.256"]} - assert get_ip_from_request(None, headers) is None # Should return None - - # Test case 6: Empty headers and remote address - headers = {} + headers = Headers() + headers.store_header("x-forwarded-for", "abc.def.ghi.jkl, 256.256.256.256") assert get_ip_from_request(None, headers) is None # Should return None From aa31af9b83b77d667e78b866c78a80c4993fe6c4 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:31:14 +0200 Subject: [PATCH 18/21] rename get_ip_form_request to get_ip_from_request --- .../{get_ip_form_request_test.py => get_ip_from_request_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename aikido_zen/helpers/{get_ip_form_request_test.py => get_ip_from_request_test.py} (100%) diff --git a/aikido_zen/helpers/get_ip_form_request_test.py b/aikido_zen/helpers/get_ip_from_request_test.py similarity index 100% rename from aikido_zen/helpers/get_ip_form_request_test.py rename to aikido_zen/helpers/get_ip_from_request_test.py From 57662c2e05226ccbec28a060d2cab695d7d89051 Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:32:56 +0200 Subject: [PATCH 19/21] Fix import in headers.py --- aikido_zen/helpers/headers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aikido_zen/helpers/headers.py b/aikido_zen/helpers/headers.py index c384d62e..1bea4b35 100644 --- a/aikido_zen/helpers/headers.py +++ b/aikido_zen/helpers/headers.py @@ -1,6 +1,4 @@ -from typing import List - -from typing_extensions import Optional +from typing import List, Optional class Headers(dict): From 1e68e552c8c9001b7f9a772e58c418a10c9c087e Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:37:54 +0200 Subject: [PATCH 20/21] request_handler test cases now use new Headers class --- aikido_zen/sources/functions/request_handler_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aikido_zen/sources/functions/request_handler_test.py b/aikido_zen/sources/functions/request_handler_test.py index e13f6bb7..58c322ff 100644 --- a/aikido_zen/sources/functions/request_handler_test.py +++ b/aikido_zen/sources/functions/request_handler_test.py @@ -4,6 +4,7 @@ from .request_handler import request_handler from ...background_process.service_config import ServiceConfig from ...context import Context, current_context +from ...helpers.headers import Headers @pytest.fixture @@ -107,13 +108,15 @@ def test_post_response_no_context(mock_get_comms): # Test firewall lists def set_context(remote_address, user_agent=""): + headers = Headers() + headers.store_header("USER_AGENT", user_agent) Context( context_obj={ "remote_address": remote_address, "method": "POST", "url": "http://localhost:4000", "query": {"abc": "def"}, - "headers": {"USER_AGENT": user_agent}, + "headers": headers, "body": None, "cookies": {}, "source": "flask", From 5ec08e4839f824a5b74f39a245fd1b827d565c3d Mon Sep 17 00:00:00 2001 From: BitterPanda Date: Wed, 2 Jul 2025 15:41:31 +0200 Subject: [PATCH 21/21] fix on_detected_attack test cases by storing value in array --- .../on_detected_attack_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py b/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py index ca6ab28d..d8fd6413 100644 --- a/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py +++ b/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py @@ -146,12 +146,12 @@ def test_on_detected_attack_request_data_and_attack_data( assert request_data["ipAddress"] == "198.51.100.23" assert request_data["body"] == 123 assert request_data["headers"] == { - "CONTENT_TYPE": "application/json", - "USER_AGENT": "Mozilla/5.0", - "COOKIE": "sessionId=abc123xyz456;", - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "HOST": "localhost:8080", + "CONTENT_TYPE": ["application/json"], + "USER_AGENT": ["Mozilla/5.0"], + "COOKIE": ["sessionId=abc123xyz456;"], + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "HOST": ["localhost:8080"], } assert request_data["source"] == "django" assert request_data["route"] == "/hello"