diff --git a/aikido_zen/api_discovery/get_api_info_test.py b/aikido_zen/api_discovery/get_api_info_test.py index e293a367d..58f796074 100644 --- a/aikido_zen/api_discovery/get_api_info_test.py +++ b/aikido_zen/api_discovery/get_api_info_test.py @@ -1,5 +1,6 @@ import pytest from .get_api_info import get_api_info +from ..helpers.headers import Headers class Context: @@ -16,7 +17,8 @@ def __init__( self.path = path self.body = body self.xml = xml - self.headers = {"CONTENT_TYPE": content_type} + self.headers = Headers() + self.headers.store_header("CONTENT_TYPE", content_type) self.query = query self.cookies = {} @@ -154,7 +156,7 @@ def test_auth_get_api_info(monkeypatch): }, content_type="application/json", ) - context1.headers["AUTHORIZATION"] = "Bearer token" + context1.headers.store_header("AUTHORIZATION", "Bearer token") api_info = get_api_info(context1) assert api_info == { "body": { diff --git a/aikido_zen/api_discovery/get_auth_type_test.py b/aikido_zen/api_discovery/get_auth_type_test.py index 48865d2bc..ad3e51cc2 100644 --- a/aikido_zen/api_discovery/get_auth_type_test.py +++ b/aikido_zen/api_discovery/get_auth_type_test.py @@ -1,10 +1,13 @@ import pytest from .get_auth_types import get_auth_types +from ..helpers.headers import Headers class Context: def __init__(self, headers={}, cookies={}): - self.headers = headers + self.headers = Headers() + for k, v in headers.items(): + self.headers.store_header(k, v) self.cookies = cookies diff --git a/aikido_zen/api_discovery/get_auth_types.py b/aikido_zen/api_discovery/get_auth_types.py index 30263d259..9312829c1 100644 --- a/aikido_zen/api_discovery/get_auth_types.py +++ b/aikido_zen/api_discovery/get_auth_types.py @@ -1,5 +1,6 @@ """Usefull export : get_auth_types""" +from aikido_zen.helpers.headers import Headers from aikido_zen.helpers.is_http_auth_scheme import is_http_auth_scheme common_api_key_header_names = [ @@ -26,13 +27,13 @@ def get_auth_types(context): """Get the authentication type of the API request.""" - if not isinstance(context.headers, dict): + if not isinstance(context.headers, Headers): return None result = [] # Check the Authorization header - auth_header = context.headers.get("AUTHORIZATION") + auth_header = context.headers.get_header("AUTHORIZATION") if isinstance(auth_header, str): auth_header_type = get_authorization_header_type(auth_header) if auth_header_type: diff --git a/aikido_zen/api_discovery/get_body_data_type.py b/aikido_zen/api_discovery/get_body_data_type.py index c4448a038..a7a03a8ad 100644 --- a/aikido_zen/api_discovery/get_body_data_type.py +++ b/aikido_zen/api_discovery/get_body_data_type.py @@ -1,5 +1,7 @@ """Exports get_body_data_type""" +from aikido_zen.helpers.headers import Headers + JSON_CONTENT_TYPES = [ "application/json", "application/vnd.api+json", @@ -8,12 +10,12 @@ ] -def get_body_data_type(headers): +def get_body_data_type(headers: Headers): """Gets the type of body data from headers""" - if not isinstance(headers, dict) or headers is None: + if not isinstance(headers, Headers) or headers is None: return - content_type = headers.get("CONTENT_TYPE") + content_type = headers.get_header("CONTENT_TYPE") if not content_type: return diff --git a/aikido_zen/api_discovery/get_body_data_type_test.py b/aikido_zen/api_discovery/get_body_data_type_test.py index 7ae144b1a..1b0d8870e 100644 --- a/aikido_zen/api_discovery/get_body_data_type_test.py +++ b/aikido_zen/api_discovery/get_body_data_type_test.py @@ -1,23 +1,51 @@ import pytest from .get_body_data_type import get_body_data_type +from ..helpers.headers import Headers def test_get_body_data_type(): - assert get_body_data_type({"CONTENT_TYPE": "application/json"}) == "json" - assert get_body_data_type({"CONTENT_TYPE": "application/vnd.api+json"}) == "json" - assert get_body_data_type({"CONTENT_TYPE": "application/csp-report"}) == "json" - assert get_body_data_type({"CONTENT_TYPE": "application/x-json"}) == "json" - assert ( - get_body_data_type({"CONTENT_TYPE": "application/x-www-form-urlencoded"}) - == "form-urlencoded" - ) - assert get_body_data_type({"CONTENT_TYPE": "multipart/form-data"}) == "form-data" - assert get_body_data_type({"CONTENT_TYPE": "text/xml"}) == "xml" - assert get_body_data_type({"CONTENT_TYPE": "text/html"}) is None - assert ( - get_body_data_type({"CONTENT_TYPE": ["application/json", "text/html"]}) - == "json" - ) - assert get_body_data_type({"x-test": "abc"}) is None - assert get_body_data_type(None) is None # Testing invalid input - assert get_body_data_type({}) is None + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/json") + assert get_body_data_type(headers) == "json" + + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/vnd.api+json") + assert get_body_data_type(headers) == "json" + + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/csp-report") + assert get_body_data_type(headers) == "json" + + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/x-json") + assert get_body_data_type(headers) == "json" + + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/x-www-form-urlencoded") + assert get_body_data_type(headers) == "form-urlencoded" + + headers = Headers() + headers.store_header("CONTENT_TYPE", "multipart/form-data") + assert get_body_data_type(headers) == "form-data" + + headers = Headers() + headers.store_header("CONTENT_TYPE", "text/xml") + assert get_body_data_type(headers) == "xml" + + headers = Headers() + headers.store_header("CONTENT_TYPE", "text/html") + assert get_body_data_type(headers) is None + + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/json, text/html") + assert get_body_data_type(headers) == "json" + + headers = Headers() + headers.store_header("x-test", "abc") + assert get_body_data_type(headers) is None + + headers = Headers() + assert get_body_data_type(headers) is None # Testing invalid input + + headers = Headers() + assert get_body_data_type(headers) is None diff --git a/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py b/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py index ca6ab28df..d8fd6413e 100644 --- a/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py +++ b/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py @@ -146,12 +146,12 @@ def test_on_detected_attack_request_data_and_attack_data( assert request_data["ipAddress"] == "198.51.100.23" assert request_data["body"] == 123 assert request_data["headers"] == { - "CONTENT_TYPE": "application/json", - "USER_AGENT": "Mozilla/5.0", - "COOKIE": "sessionId=abc123xyz456;", - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "HOST": "localhost:8080", + "CONTENT_TYPE": ["application/json"], + "USER_AGENT": ["Mozilla/5.0"], + "COOKIE": ["sessionId=abc123xyz456;"], + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "HOST": ["localhost:8080"], } assert request_data["source"] == "django" assert request_data["route"] == "/hello" diff --git a/aikido_zen/background_process/routes/init_test.py b/aikido_zen/background_process/routes/init_test.py index 44471f623..ee1ea13ff 100644 --- a/aikido_zen/background_process/routes/init_test.py +++ b/aikido_zen/background_process/routes/init_test.py @@ -1,5 +1,6 @@ from aikido_zen.background_process.routes import Routes from aikido_zen.api_discovery.get_api_info import get_api_info +from aikido_zen.helpers.headers import Headers class Context: @@ -18,8 +19,11 @@ def __init__( self.route = path self.body = body self.xml = xml - self.headers = headers - self.headers["CONTENT_TYPE"] = content_type + self.raw_headers = headers + self.raw_headers["CONTENT_TYPE"] = content_type + self.headers = Headers() + for k, v in self.raw_headers.items(): + self.headers.store_header(k, v) self.query = query self.cookies = cookies diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index af541ecdb..83b9acb26 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -14,6 +14,7 @@ from .wsgi import set_wsgi_attributes_on_context from .asgi import set_asgi_attributes_on_context from .extract_route_params import extract_route_params +from ..helpers.headers import Headers UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml", "route_params"] current_context = contextvars.ContextVar("current_context", default=None) @@ -48,11 +49,12 @@ def __init__(self, context_obj=None, body=None, req=None, source=None): self.xml = {} self.outgoing_req_redirects = [] self.set_body(body) + self.headers: Headers = Headers() + self.cookies = dict() + self.query = dict() # Parse WSGI/ASGI/... request : - self.cookies = self.method = self.remote_address = self.query = self.headers = ( - self.url - ) = None + self.method = self.remote_address = self.url = None if source in WSGI_SOURCES: set_wsgi_attributes_on_context(self, req) elif source in ASGI_SOURCES: @@ -128,8 +130,4 @@ def get_route_metadata(self): } def get_user_agent(self): - if "USER_AGENT" not in self.headers: - return None - if isinstance(self.headers["USER_AGENT"], list): - return self.headers["USER_AGENT"][-1] - return self.headers["USER_AGENT"] + return self.headers.get_header("USER_AGENT") diff --git a/aikido_zen/context/asgi/__init__.py b/aikido_zen/context/asgi/__init__.py index d083ece0f..0b6beaf12 100644 --- a/aikido_zen/context/asgi/__init__.py +++ b/aikido_zen/context/asgi/__init__.py @@ -17,10 +17,10 @@ def set_asgi_attributes_on_context(context, scope): context.method = scope["method"] context.headers = normalize_asgi_headers(scope["headers"]) - if "COOKIE" in context.headers and len(context.headers["COOKIE"]) > 0: + if context.headers.get_header("COOKIE"): # Right now just use the first Cookie header, will change later to use # framework definition for cookies. - context.cookies = parse_cookies(context.headers["COOKIE"][0]) + context.cookies = parse_cookies(context.headers.get_header("COOKIE")) else: context.cookies = {} diff --git a/aikido_zen/context/asgi/normalize_asgi_headers.py b/aikido_zen/context/asgi/normalize_asgi_headers.py index c62e7ac2b..915a753f5 100644 --- a/aikido_zen/context/asgi/normalize_asgi_headers.py +++ b/aikido_zen/context/asgi/normalize_asgi_headers.py @@ -1,17 +1,14 @@ """Mainly exports normalize_asgi_headers""" +from aikido_zen.helpers.headers import Headers -def normalize_asgi_headers(headers): + +def normalize_asgi_headers(headers) -> Headers: """ Normalizes headers provided by ASGI : Decodes them, uppercase and underscore keys """ - parsed_headers = {} + result = Headers() for k, v in headers: - # Normalizing key : decoding, removing dashes and uppercase - key_without_dashes = k.decode("utf-8").replace("-", "_") - key_normalized = key_without_dashes.upper() - if not key_normalized in parsed_headers: - parsed_headers[key_normalized] = list() - parsed_headers[key_normalized].append(v.decode("utf-8")) - return parsed_headers + result.store_header(k.decode("utf-8"), v.decode("utf-8")) + return result diff --git a/aikido_zen/context/init_test.py b/aikido_zen/context/init_test.py index 4769167ba..b9118337a 100644 --- a/aikido_zen/context/init_test.py +++ b/aikido_zen/context/init_test.py @@ -52,11 +52,11 @@ def test_wsgi_context_1(): "source": "django", "method": "POST", "headers": { - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "COOKIE": "sessionId=abc123xyz456;", - "HOST": "example.com", - "CONTENT_TYPE": "application/x-www-form-urlencoded", + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "COOKIE": ["sessionId=abc123xyz456;"], + "HOST": ["example.com"], + "CONTENT_TYPE": ["application/x-www-form-urlencoded"], }, "cookies": {"sessionId": "abc123xyz456"}, "url": "https://example.com/hello", @@ -81,12 +81,12 @@ def test_wsgi_context_2(): "source": "flask", "method": "GET", "headers": { - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "COOKIE": "sessionId=abc123xyz456;", - "HOST": "localhost:8080", - "CONTENT_TYPE": "application/json", - "USER_AGENT": "Mozilla/5.0", + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "COOKIE": ["sessionId=abc123xyz456;"], + "HOST": ["localhost:8080"], + "CONTENT_TYPE": ["application/json"], + "USER_AGENT": ["Mozilla/5.0"], }, "cookies": {"sessionId": "abc123xyz456"}, "url": "http://localhost:8080/hello", @@ -129,12 +129,12 @@ def test_context_is_picklable(mocker): assert unpickled_obj.url == "http://localhost:8080/hello" assert unpickled_obj.body == 123 assert unpickled_obj.headers == { - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "COOKIE": "sessionId=abc123xyz456;", - "HOST": "localhost:8080", - "CONTENT_TYPE": "application/json", - "USER_AGENT": "Mozilla/5.0", + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "COOKIE": ["sessionId=abc123xyz456;"], + "HOST": ["localhost:8080"], + "CONTENT_TYPE": ["application/json"], + "USER_AGENT": ["Mozilla/5.0"], } assert unpickled_obj.query == {"user": ["JohnDoe"], "age": ["30", "35"]} assert unpickled_obj.cookies == {"sessionId": "abc123xyz456"} diff --git a/aikido_zen/context/wsgi/__init__.py b/aikido_zen/context/wsgi/__init__.py index 6c0f5c895..f90a47369 100644 --- a/aikido_zen/context/wsgi/__init__.py +++ b/aikido_zen/context/wsgi/__init__.py @@ -17,8 +17,8 @@ def set_wsgi_attributes_on_context(context, environ): context.method = environ["REQUEST_METHOD"] context.headers = extract_wsgi_headers(environ) - if "COOKIE" in context.headers: - context.cookies = parse_cookies(context.headers["COOKIE"]) + if context.headers.get_header("COOKIE"): + context.cookies = parse_cookies(context.headers.get_header("COOKIE")) else: context.cookies = {} context.url = build_url_from_wsgi(environ) @@ -29,4 +29,4 @@ def set_wsgi_attributes_on_context(context, environ): # Content type is generally not included as a header, do include this as a header to simplify : if "CONTENT_TYPE" in environ: - context.headers["CONTENT_TYPE"] = environ["CONTENT_TYPE"] + context.headers.store_header("CONTENT_TYPE", environ["CONTENT_TYPE"]) diff --git a/aikido_zen/context/wsgi/extract_wsgi_headers.py b/aikido_zen/context/wsgi/extract_wsgi_headers.py index c20cc3ae4..b13cfee4c 100644 --- a/aikido_zen/context/wsgi/extract_wsgi_headers.py +++ b/aikido_zen/context/wsgi/extract_wsgi_headers.py @@ -1,11 +1,14 @@ """Exports function extract_wsgi_headers""" +from aikido_zen.helpers.headers import Headers -def extract_wsgi_headers(request): + +def extract_wsgi_headers(request) -> Headers: """Extracts WSGI headers which start with HTTP_ from request dict""" - headers = {} + headers = Headers() for key, value in request.items(): if key.startswith("HTTP_"): # Remove the 'HTTP_' prefix and store in the headers dictionary - headers[key[5:]] = value + header_key = key[5:] + headers.store_header(header_key, value) return headers diff --git a/aikido_zen/context/wsgi/extract_wsgi_headers_test.py b/aikido_zen/context/wsgi/extract_wsgi_headers_test.py index 85fc478dc..6336f3a73 100644 --- a/aikido_zen/context/wsgi/extract_wsgi_headers_test.py +++ b/aikido_zen/context/wsgi/extract_wsgi_headers_test.py @@ -4,7 +4,7 @@ def test_extract_wsgi_headers_single_header(): request = {"REQUEST_METHOD": "GET", "HTTP_USER_AGENT": "Mozilla/5.0"} - expected = {"USER_AGENT": "Mozilla/5.0"} + expected = {"USER_AGENT": ["Mozilla/5.0"]} assert extract_wsgi_headers(request) == expected @@ -16,9 +16,9 @@ def test_extract_wsgi_headers_multiple_headers(): "HTTP_CONTENT_TYPE": "application/json", } expected = { - "HOST": "example.com", - "ACCEPT": "text/html", - "CONTENT_TYPE": "application/json", + "HOST": ["example.com"], + "ACCEPT": ["text/html"], + "CONTENT_TYPE": ["application/json"], } assert extract_wsgi_headers(request) == expected @@ -42,5 +42,5 @@ def test_extract_wsgi_headers_mixed_headers(): "HTTP_ACCEPT_LANGUAGE": "en-US,en;q=0.5", "OTHER_HEADER": "value", } - expected = {"USER_AGENT": "Mozilla/5.0", "ACCEPT_LANGUAGE": "en-US,en;q=0.5"} + expected = {"USER_AGENT": ["Mozilla/5.0"], "ACCEPT_LANGUAGE": ["en-US,en;q=0.5"]} assert extract_wsgi_headers(request) == expected diff --git a/aikido_zen/helpers/get_ip_from_request.py b/aikido_zen/helpers/get_ip_from_request.py index e67a0a306..eb46a8d47 100644 --- a/aikido_zen/helpers/get_ip_from_request.py +++ b/aikido_zen/helpers/get_ip_from_request.py @@ -4,22 +4,23 @@ import socket import os +from typing import Dict, List, Optional + +from aikido_zen.helpers.headers import Headers from aikido_zen.helpers.logging import logger -def get_ip_from_request(remote_address, headers): +def get_ip_from_request(remote_address: str, headers: Headers) -> Optional[str]: """ Tries and get the IP address from the request, checking for x-forwarded-for """ - if headers: - lower_headers = {key.lower(): value for key, value in headers.items()} - if "x_forwarded_for" in lower_headers and trust_proxy(): - x_forwarded_for = get_client_ip_from_x_forwarded_for( - lower_headers["x_forwarded_for"] - ) - - if x_forwarded_for and is_ip(x_forwarded_for): - return x_forwarded_for + if headers.get_header("X_FORWARDED_FOR") and trust_proxy(): + x_forwarded_for = get_client_ip_from_x_forwarded_for( + headers.get_header("X_FORWARDED_FOR") + ) + + if x_forwarded_for and is_ip(x_forwarded_for): + return x_forwarded_for if remote_address and is_ip(remote_address): return remote_address diff --git a/aikido_zen/helpers/get_ip_form_request_test.py b/aikido_zen/helpers/get_ip_from_request_test.py similarity index 87% rename from aikido_zen/helpers/get_ip_form_request_test.py rename to aikido_zen/helpers/get_ip_from_request_test.py index 6935ea89c..c59732daa 100644 --- a/aikido_zen/helpers/get_ip_form_request_test.py +++ b/aikido_zen/helpers/get_ip_from_request_test.py @@ -4,36 +4,37 @@ is_ip, get_client_ip_from_x_forwarded_for, ) +from .headers import Headers # Test `get_ip_from_request` function : def test_get_ip_from_request(): # Test case 1: Valid X_FORWARDED_FOR header with valid IP - headers = {"X_FORWARDED_FOR": "192.168.1.1, 10.0.0.1"} + headers = Headers() + headers.store_header("x-forwarded-for", "192.168.1.1, 10.0.0.1") assert get_ip_from_request(None, headers) == "192.168.1.1" # Test case 2: Valid X_FORWARDED_FOR header with invalid IPs - headers = {"X_FORWARDED_FOR": "256.256.256.256, 192.168.1.1"} + headers = Headers() + headers.store_header("x-forwarded-for", "256.256.256.256, 192.168.1.1") assert ( get_ip_from_request(None, headers) == "192.168.1.1" ) # Should return the valid IP # Test case 3: Valid remote address - headers = {} + headers = Headers() assert get_ip_from_request("10.0.0.1", headers) == "10.0.0.1" # Test case 4: Valid remote address with invalid X_FORWARDED_FOR - headers = {"X_FORWARDED_FOR": "abc.def.ghi.jkl, 256.256.256.256"} + headers = Headers() + headers.store_header("x-forwarded-for", "abc.def.ghi.jkl, 256.256.256.256") assert ( get_ip_from_request("10.0.0.1", headers) == "10.0.0.1" ) # Should return the remote address # Test case 5: Both X_FORWARDED_FOR and remote address are invalid - headers = {"X_FORWARDED_FOR": "abc.def.ghi.jkl, 256.256.256.256"} - assert get_ip_from_request(None, headers) is None # Should return None - - # Test case 6: Empty headers and remote address - headers = {} + headers = Headers() + headers.store_header("x-forwarded-for", "abc.def.ghi.jkl, 256.256.256.256") assert get_ip_from_request(None, headers) is None # Should return None diff --git a/aikido_zen/helpers/headers.py b/aikido_zen/helpers/headers.py new file mode 100644 index 000000000..1bea4b350 --- /dev/null +++ b/aikido_zen/helpers/headers.py @@ -0,0 +1,33 @@ +from typing import List, Optional + + +class Headers(dict): + def store_headers(self, key: str, values: List[str]): + normalized_key = self.normalize_header_key(key) + if self.get(normalized_key, []): + self[normalized_key] += values + else: + self[normalized_key] = values + + def store_header(self, key: str, value: str): + self.store_headers(key, [value]) + + def get_header(self, key: str) -> Optional[str]: + self.validate_header_key(key) + if self.get(key, []): + return self.get(key)[-1] + return None + + @staticmethod + def validate_header_key(key: str): + if not key.isupper(): + raise ValueError("Header key must be uppercase.") + if "-" in key: + raise ValueError("Header key must use underscores instead of dashes.") + + @staticmethod + def normalize_header_key(key: str) -> str: + result = str(key) + result = result.replace("-", "_") + result = result.upper() + return result diff --git a/aikido_zen/helpers/headers_test.py b/aikido_zen/helpers/headers_test.py new file mode 100644 index 000000000..69b2dcd9a --- /dev/null +++ b/aikido_zen/helpers/headers_test.py @@ -0,0 +1,72 @@ +import pytest +from .headers import Headers + + +def test_store_header_single_value(): + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/json") + assert headers["CONTENT_TYPE"] == ["application/json"] + + +def test_store_header_multiple_values(): + headers = Headers() + headers.store_headers("CONTENT_TYPE", ["application/json"]) + headers.store_headers("CONTENT_TYPE", ["text/html"]) + assert headers["CONTENT_TYPE"] == ["application/json", "text/html"] + + +def test_get_header_existing_key(): + headers = Headers() + headers.store_header("CONTENT_TYPE", "application/json") + assert headers.get_header("CONTENT_TYPE") == "application/json" + + +def test_get_header_non_existing_key(): + headers = Headers() + assert headers.get_header("NON_EXISTING_KEY") is None + + +def test_get_header_multiple_values(): + headers = Headers() + headers.store_headers("CONTENT_TYPE", ["application/json", "text/html"]) + assert headers.get_header("CONTENT_TYPE") == "text/html" + + +def test_validate_header_key_valid(): + try: + Headers.validate_header_key("VALID_HEADER") + except ValueError: + pytest.fail("validate_header_key raised ValueError unexpectedly!") + + +def test_validate_header_key_not_uppercase(): + with pytest.raises(ValueError, match="Header key must be uppercase."): + Headers.validate_header_key("invalid_header") + + +def test_validate_header_key_with_dash(): + with pytest.raises( + ValueError, match="Header key must use underscores instead of dashes." + ): + Headers.validate_header_key("INVALID-HEADER") + + +def test_normalize_header_key_valid(): + assert Headers.normalize_header_key("valid-header") == "VALID_HEADER" + assert Headers.normalize_header_key("ANOTHER-HEADER") == "ANOTHER_HEADER" + + +def test_normalize_header_key_already_normalized(): + assert Headers.normalize_header_key("ALREADY_NORMALIZED") == "ALREADY_NORMALIZED" + + +def test_store_headers_with_empty_list(): + headers = Headers() + headers.store_headers("CONTENT_TYPE", []) + assert headers.get_header("CONTENT_TYPE") is None + + +def test_store_header_with_empty_string(): + headers = Headers() + headers.store_header("CONTENT_TYPE", "") + assert headers.get_header("CONTENT_TYPE") == "" diff --git a/aikido_zen/sources/flask/flask_test.py b/aikido_zen/sources/flask/flask_test.py index 4a1d448ab..9d25f765a 100644 --- a/aikido_zen/sources/flask/flask_test.py +++ b/aikido_zen/sources/flask/flask_test.py @@ -104,11 +104,11 @@ def hello(user, age): assert get_current_context().method == "POST" assert get_current_context().body is None assert get_current_context().headers == { - "COOKIE": "sessionId=abc123xyz456;", - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "HOST": "example.com", - "CONTENT_TYPE": "application/json", + "COOKIE": ["sessionId=abc123xyz456;"], + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "HOST": ["example.com"], + "CONTENT_TYPE": ["application/json"], } calls = mock_request_handler.call_args_list assert len(calls) == 2 @@ -175,11 +175,11 @@ def test_flask_all_3_func_with_invalid_body(): get_current_context().body == None ) # body is None since it's invalid json assert get_current_context().headers == { - "COOKIE": "sessionId=abc123xyz456;", - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "HOST": "example.com", - "CONTENT_TYPE": "application/json", + "COOKIE": ["sessionId=abc123xyz456;"], + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "HOST": ["example.com"], + "CONTENT_TYPE": ["application/json"], } calls = mock_request_handler.call_args_list assert len(calls) == 2 @@ -208,11 +208,11 @@ def test_flask_all_3_func(): assert get_current_context().method == "POST" assert get_current_context().body == None assert get_current_context().headers == { - "COOKIE": "sessionId=abc123xyz456;", - "HEADER_1": "header 1 value", - "HEADER_2": "Header 2 value", - "HOST": "example.com", - "CONTENT_TYPE": "application/x-www-form-urlencoded", + "COOKIE": ["sessionId=abc123xyz456;"], + "HEADER_1": ["header 1 value"], + "HEADER_2": ["Header 2 value"], + "HOST": ["example.com"], + "CONTENT_TYPE": ["application/x-www-form-urlencoded"], } calls = mock_request_handler.call_args_list assert len(calls) == 2 diff --git a/aikido_zen/sources/functions/request_handler_test.py b/aikido_zen/sources/functions/request_handler_test.py index e13f6bb74..58c322ffb 100644 --- a/aikido_zen/sources/functions/request_handler_test.py +++ b/aikido_zen/sources/functions/request_handler_test.py @@ -4,6 +4,7 @@ from .request_handler import request_handler from ...background_process.service_config import ServiceConfig from ...context import Context, current_context +from ...helpers.headers import Headers @pytest.fixture @@ -107,13 +108,15 @@ def test_post_response_no_context(mock_get_comms): # Test firewall lists def set_context(remote_address, user_agent=""): + headers = Headers() + headers.store_header("USER_AGENT", user_agent) Context( context_obj={ "remote_address": remote_address, "method": "POST", "url": "http://localhost:4000", "query": {"abc": "def"}, - "headers": {"USER_AGENT": user_agent}, + "headers": headers, "body": None, "cookies": {}, "source": "flask",