Skip to content

Create new Headers class, so that the codebase is consistent in handling headers #422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions aikido_zen/api_discovery/get_api_info_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from .get_api_info import get_api_info
from ..helpers.headers import Headers


class Context:
Expand All @@ -16,7 +17,8 @@ def __init__(
self.path = path
self.body = body
self.xml = xml
self.headers = {"CONTENT_TYPE": content_type}
self.headers = Headers()
self.headers.store_header("CONTENT_TYPE", content_type)
self.query = query
self.cookies = {}

Expand Down Expand Up @@ -154,7 +156,7 @@ def test_auth_get_api_info(monkeypatch):
},
content_type="application/json",
)
context1.headers["AUTHORIZATION"] = "Bearer token"
context1.headers.store_header("AUTHORIZATION", "Bearer token")
api_info = get_api_info(context1)
assert api_info == {
"body": {
Expand Down
5 changes: 4 additions & 1 deletion aikido_zen/api_discovery/get_auth_type_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import pytest
from .get_auth_types import get_auth_types
from ..helpers.headers import Headers


class Context:
def __init__(self, headers={}, cookies={}):
self.headers = headers
self.headers = Headers()
for k, v in headers.items():
self.headers.store_header(k, v)
self.cookies = cookies


Expand Down
5 changes: 3 additions & 2 deletions aikido_zen/api_discovery/get_auth_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Usefull export : get_auth_types"""

from aikido_zen.helpers.headers import Headers
from aikido_zen.helpers.is_http_auth_scheme import is_http_auth_scheme

common_api_key_header_names = [
Expand All @@ -26,13 +27,13 @@

def get_auth_types(context):
"""Get the authentication type of the API request."""
if not isinstance(context.headers, dict):
if not isinstance(context.headers, Headers):
return None

result = []

# Check the Authorization header
auth_header = context.headers.get("AUTHORIZATION")
auth_header = context.headers.get_header("AUTHORIZATION")
if isinstance(auth_header, str):
auth_header_type = get_authorization_header_type(auth_header)
if auth_header_type:
Expand Down
8 changes: 5 additions & 3 deletions aikido_zen/api_discovery/get_body_data_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Exports get_body_data_type"""

from aikido_zen.helpers.headers import Headers

JSON_CONTENT_TYPES = [
"application/json",
"application/vnd.api+json",
Expand All @@ -8,12 +10,12 @@
]


def get_body_data_type(headers):
def get_body_data_type(headers: Headers):
"""Gets the type of body data from headers"""
if not isinstance(headers, dict) or headers is None:
if not isinstance(headers, Headers) or headers is None:
return

content_type = headers.get("CONTENT_TYPE")
content_type = headers.get_header("CONTENT_TYPE")
if not content_type:
return

Expand Down
64 changes: 46 additions & 18 deletions aikido_zen/api_discovery/get_body_data_type_test.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,51 @@
import pytest
from .get_body_data_type import get_body_data_type
from ..helpers.headers import Headers


def test_get_body_data_type():
assert get_body_data_type({"CONTENT_TYPE": "application/json"}) == "json"
assert get_body_data_type({"CONTENT_TYPE": "application/vnd.api+json"}) == "json"
assert get_body_data_type({"CONTENT_TYPE": "application/csp-report"}) == "json"
assert get_body_data_type({"CONTENT_TYPE": "application/x-json"}) == "json"
assert (
get_body_data_type({"CONTENT_TYPE": "application/x-www-form-urlencoded"})
== "form-urlencoded"
)
assert get_body_data_type({"CONTENT_TYPE": "multipart/form-data"}) == "form-data"
assert get_body_data_type({"CONTENT_TYPE": "text/xml"}) == "xml"
assert get_body_data_type({"CONTENT_TYPE": "text/html"}) is None
assert (
get_body_data_type({"CONTENT_TYPE": ["application/json", "text/html"]})
== "json"
)
assert get_body_data_type({"x-test": "abc"}) is None
assert get_body_data_type(None) is None # Testing invalid input
assert get_body_data_type({}) is None
headers = Headers()
headers.store_header("CONTENT_TYPE", "application/json")
assert get_body_data_type(headers) == "json"

headers = Headers()
headers.store_header("CONTENT_TYPE", "application/vnd.api+json")
assert get_body_data_type(headers) == "json"

headers = Headers()
headers.store_header("CONTENT_TYPE", "application/csp-report")
assert get_body_data_type(headers) == "json"

headers = Headers()
headers.store_header("CONTENT_TYPE", "application/x-json")
assert get_body_data_type(headers) == "json"

headers = Headers()
headers.store_header("CONTENT_TYPE", "application/x-www-form-urlencoded")
assert get_body_data_type(headers) == "form-urlencoded"

headers = Headers()
headers.store_header("CONTENT_TYPE", "multipart/form-data")
assert get_body_data_type(headers) == "form-data"

headers = Headers()
headers.store_header("CONTENT_TYPE", "text/xml")
assert get_body_data_type(headers) == "xml"

headers = Headers()
headers.store_header("CONTENT_TYPE", "text/html")
assert get_body_data_type(headers) is None

headers = Headers()
headers.store_header("CONTENT_TYPE", "application/json, text/html")
assert get_body_data_type(headers) == "json"

headers = Headers()
headers.store_header("x-test", "abc")
assert get_body_data_type(headers) is None

headers = Headers()
assert get_body_data_type(headers) is None # Empty Headers object: no content type stored

headers = Headers()
assert get_body_data_type(headers) is None
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ def test_on_detected_attack_request_data_and_attack_data(
assert request_data["ipAddress"] == "198.51.100.23"
assert request_data["body"] == 123
assert request_data["headers"] == {
"CONTENT_TYPE": "application/json",
"USER_AGENT": "Mozilla/5.0",
"COOKIE": "sessionId=abc123xyz456;",
"HEADER_1": "header 1 value",
"HEADER_2": "Header 2 value",
"HOST": "localhost:8080",
"CONTENT_TYPE": ["application/json"],
"USER_AGENT": ["Mozilla/5.0"],
"COOKIE": ["sessionId=abc123xyz456;"],
"HEADER_1": ["header 1 value"],
"HEADER_2": ["Header 2 value"],
"HOST": ["localhost:8080"],
}
assert request_data["source"] == "django"
assert request_data["route"] == "/hello"
Expand Down
8 changes: 6 additions & 2 deletions aikido_zen/background_process/routes/init_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from aikido_zen.background_process.routes import Routes
from aikido_zen.api_discovery.get_api_info import get_api_info
from aikido_zen.helpers.headers import Headers


class Context:
Expand All @@ -18,8 +19,11 @@ def __init__(
self.route = path
self.body = body
self.xml = xml
self.headers = headers
self.headers["CONTENT_TYPE"] = content_type
self.raw_headers = headers
self.raw_headers["CONTENT_TYPE"] = content_type
self.headers = Headers()
for k, v in self.raw_headers.items():
self.headers.store_header(k, v)
self.query = query
self.cookies = cookies

Expand Down
14 changes: 6 additions & 8 deletions aikido_zen/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .wsgi import set_wsgi_attributes_on_context
from .asgi import set_asgi_attributes_on_context
from .extract_route_params import extract_route_params
from ..helpers.headers import Headers

UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml", "route_params"]
current_context = contextvars.ContextVar("current_context", default=None)
Expand Down Expand Up @@ -48,11 +49,12 @@ def __init__(self, context_obj=None, body=None, req=None, source=None):
self.xml = {}
self.outgoing_req_redirects = []
self.set_body(body)
self.headers: Headers = Headers()
self.cookies = dict()
self.query = dict()

# Parse WSGI/ASGI/... request :
self.cookies = self.method = self.remote_address = self.query = self.headers = (
self.url
) = None
self.method = self.remote_address = self.url = None
if source in WSGI_SOURCES:
set_wsgi_attributes_on_context(self, req)
elif source in ASGI_SOURCES:
Expand Down Expand Up @@ -128,8 +130,4 @@ def get_route_metadata(self):
}

def get_user_agent(self):
if "USER_AGENT" not in self.headers:
return None
if isinstance(self.headers["USER_AGENT"], list):
return self.headers["USER_AGENT"][-1]
return self.headers["USER_AGENT"]
return self.headers.get_header("USER_AGENT")
4 changes: 2 additions & 2 deletions aikido_zen/context/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def set_asgi_attributes_on_context(context, scope):
context.method = scope["method"]
context.headers = normalize_asgi_headers(scope["headers"])

if "COOKIE" in context.headers and len(context.headers["COOKIE"]) > 0:
if context.headers.get_header("COOKIE"):
# Right now just use the first Cookie header, will change later to use
# framework definition for cookies.
context.cookies = parse_cookies(context.headers["COOKIE"][0])
context.cookies = parse_cookies(context.headers.get_header("COOKIE"))
else:
context.cookies = {}

Expand Down
15 changes: 6 additions & 9 deletions aikido_zen/context/asgi/normalize_asgi_headers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
"""Mainly exports normalize_asgi_headers"""

from aikido_zen.helpers.headers import Headers

def normalize_asgi_headers(headers):

def normalize_asgi_headers(headers) -> Headers:
"""
Normalizes headers provided by ASGI :
Decodes them, uppercase and underscore keys
"""
parsed_headers = {}
result = Headers()
for k, v in headers:
# Normalizing key : decoding, removing dashes and uppercase
key_without_dashes = k.decode("utf-8").replace("-", "_")
key_normalized = key_without_dashes.upper()
if not key_normalized in parsed_headers:
parsed_headers[key_normalized] = list()
parsed_headers[key_normalized].append(v.decode("utf-8"))
return parsed_headers
result.store_header(k.decode("utf-8"), v.decode("utf-8"))
return result
34 changes: 17 additions & 17 deletions aikido_zen/context/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def test_wsgi_context_1():
"source": "django",
"method": "POST",
"headers": {
"HEADER_1": "header 1 value",
"HEADER_2": "Header 2 value",
"COOKIE": "sessionId=abc123xyz456;",
"HOST": "example.com",
"CONTENT_TYPE": "application/x-www-form-urlencoded",
"HEADER_1": ["header 1 value"],
"HEADER_2": ["Header 2 value"],
"COOKIE": ["sessionId=abc123xyz456;"],
"HOST": ["example.com"],
"CONTENT_TYPE": ["application/x-www-form-urlencoded"],
},
"cookies": {"sessionId": "abc123xyz456"},
"url": "https://example.com/hello",
Expand All @@ -81,12 +81,12 @@ def test_wsgi_context_2():
"source": "flask",
"method": "GET",
"headers": {
"HEADER_1": "header 1 value",
"HEADER_2": "Header 2 value",
"COOKIE": "sessionId=abc123xyz456;",
"HOST": "localhost:8080",
"CONTENT_TYPE": "application/json",
"USER_AGENT": "Mozilla/5.0",
"HEADER_1": ["header 1 value"],
"HEADER_2": ["Header 2 value"],
"COOKIE": ["sessionId=abc123xyz456;"],
"HOST": ["localhost:8080"],
"CONTENT_TYPE": ["application/json"],
"USER_AGENT": ["Mozilla/5.0"],
},
"cookies": {"sessionId": "abc123xyz456"},
"url": "http://localhost:8080/hello",
Expand Down Expand Up @@ -129,12 +129,12 @@ def test_context_is_picklable(mocker):
assert unpickled_obj.url == "http://localhost:8080/hello"
assert unpickled_obj.body == 123
assert unpickled_obj.headers == {
"HEADER_1": "header 1 value",
"HEADER_2": "Header 2 value",
"COOKIE": "sessionId=abc123xyz456;",
"HOST": "localhost:8080",
"CONTENT_TYPE": "application/json",
"USER_AGENT": "Mozilla/5.0",
"HEADER_1": ["header 1 value"],
"HEADER_2": ["Header 2 value"],
"COOKIE": ["sessionId=abc123xyz456;"],
"HOST": ["localhost:8080"],
"CONTENT_TYPE": ["application/json"],
"USER_AGENT": ["Mozilla/5.0"],
}
assert unpickled_obj.query == {"user": ["JohnDoe"], "age": ["30", "35"]}
assert unpickled_obj.cookies == {"sessionId": "abc123xyz456"}
Expand Down
6 changes: 3 additions & 3 deletions aikido_zen/context/wsgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def set_wsgi_attributes_on_context(context, environ):

context.method = environ["REQUEST_METHOD"]
context.headers = extract_wsgi_headers(environ)
if "COOKIE" in context.headers:
context.cookies = parse_cookies(context.headers["COOKIE"])
if context.headers.get_header("COOKIE"):
context.cookies = parse_cookies(context.headers.get_header("COOKIE"))
else:
context.cookies = {}
context.url = build_url_from_wsgi(environ)
Expand All @@ -29,4 +29,4 @@ def set_wsgi_attributes_on_context(context, environ):

# CONTENT_TYPE usually appears in the WSGI environ without the HTTP_ prefix, so it is not extracted as a header; store it explicitly to simplify lookups:
if "CONTENT_TYPE" in environ:
context.headers["CONTENT_TYPE"] = environ["CONTENT_TYPE"]
context.headers.store_header("CONTENT_TYPE", environ["CONTENT_TYPE"])
9 changes: 6 additions & 3 deletions aikido_zen/context/wsgi/extract_wsgi_headers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Exports function extract_wsgi_headers"""

from aikido_zen.helpers.headers import Headers

def extract_wsgi_headers(request):

def extract_wsgi_headers(request) -> Headers:
"""Extracts WSGI headers which start with HTTP_ from request dict"""
headers = {}
headers = Headers()
for key, value in request.items():
if key.startswith("HTTP_"):
# Remove the 'HTTP_' prefix and store in the headers dictionary
headers[key[5:]] = value
header_key = key[5:]
headers.store_header(header_key, value)
return headers
10 changes: 5 additions & 5 deletions aikido_zen/context/wsgi/extract_wsgi_headers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def test_extract_wsgi_headers_single_header():
request = {"REQUEST_METHOD": "GET", "HTTP_USER_AGENT": "Mozilla/5.0"}
expected = {"USER_AGENT": "Mozilla/5.0"}
expected = {"USER_AGENT": ["Mozilla/5.0"]}
assert extract_wsgi_headers(request) == expected


Expand All @@ -16,9 +16,9 @@ def test_extract_wsgi_headers_multiple_headers():
"HTTP_CONTENT_TYPE": "application/json",
}
expected = {
"HOST": "example.com",
"ACCEPT": "text/html",
"CONTENT_TYPE": "application/json",
"HOST": ["example.com"],
"ACCEPT": ["text/html"],
"CONTENT_TYPE": ["application/json"],
}
assert extract_wsgi_headers(request) == expected

Expand All @@ -42,5 +42,5 @@ def test_extract_wsgi_headers_mixed_headers():
"HTTP_ACCEPT_LANGUAGE": "en-US,en;q=0.5",
"OTHER_HEADER": "value",
}
expected = {"USER_AGENT": "Mozilla/5.0", "ACCEPT_LANGUAGE": "en-US,en;q=0.5"}
expected = {"USER_AGENT": ["Mozilla/5.0"], "ACCEPT_LANGUAGE": ["en-US,en;q=0.5"]}
assert extract_wsgi_headers(request) == expected
Loading
Loading