Skip to content

Commit db0273b

Browse files
committed
Added requests validation based on swagger schema.
1 parent be39d48 commit db0273b

File tree

13 files changed

+1157
-30
lines changed

13 files changed

+1157
-30
lines changed

aiohttp_swagger/__init__.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,21 @@
1212
generate_doc_from_each_end_point,
1313
load_doc_from_yaml_file,
1414
swagger_path,
15+
swagger_validation,
16+
add_swagger_validation,
1517
)
1618

1719
try:
1820
import ujson as json
1921
except ImportError:
2022
import json
2123

24+
__all__ = (
25+
"setup_swagger",
26+
"swagger_path",
27+
"swagger_validation",
28+
)
29+
2230

2331
@asyncio.coroutine
2432
def _swagger_home(request):
@@ -89,7 +97,7 @@ def setup_swagger(app: web.Application,
8997
)
9098

9199
if swagger_validate_schema:
92-
pass
100+
add_swagger_validation(app, swagger_info)
93101

94102
swagger_info = json.dumps(swagger_info)
95103

@@ -119,12 +127,9 @@ def setup_swagger(app: web.Application,
119127
with open(join(STATIC_PATH, "index.html"), "r") as f:
120128
app["SWAGGER_TEMPLATE_CONTENT"] = (
121129
f.read()
122-
.replace("##SWAGGER_CONFIG##", '/{}{}'.
130+
.replace("##SWAGGER_CONFIG##", '{}{}'.
123131
format(api_base_url.lstrip('/'), _swagger_def_url))
124-
.replace("##STATIC_PATH##", '/{}{}'.
132+
.replace("##STATIC_PATH##", '{}{}'.
125133
format(api_base_url.lstrip('/'), statics_path))
126134
.replace("##SWAGGER_VALIDATOR_URL##", swagger_validator_url)
127135
)
128-
129-
130-
__all__ = ("setup_swagger", "swagger_path")

aiohttp_swagger/helpers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .builders import * # noqa
22
from .decorators import * # noqa
3+
from .validation import * # noqa

aiohttp_swagger/helpers/builders.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import (
23
MutableMapping,
34
Mapping,
@@ -13,18 +14,21 @@
1314
from aiohttp import web
1415
from aiohttp.hdrs import METH_ANY, METH_ALL
1516
from jinja2 import Template
16-
1717
try:
1818
import ujson as json
1919
except ImportError: # pragma: no cover
2020
import json
2121

22+
from .validation import validate_decorator
23+
2224

2325
SWAGGER_TEMPLATE = abspath(join(dirname(__file__), "..", "templates"))
2426

2527

26-
def _extract_swagger_docs(end_point_doc, method="get"):
27-
# Find Swagger start point in doc
28+
def _extract_swagger_docs(end_point_doc: str) -> Mapping:
29+
"""
30+
Find Swagger start point in doc.
31+
"""
2832
end_point_swagger_start = 0
2933
for i, doc_line in enumerate(end_point_doc):
3034
if "---" in doc_line:
@@ -42,7 +46,7 @@ def _extract_swagger_docs(end_point_doc, method="get"):
4246
"from docstring ⚠",
4347
"tags": ["Invalid Swagger"]
4448
}
45-
return {method: end_point_swagger_doc}
49+
return end_point_swagger_doc
4650

4751

4852
def _build_doc_from_func_doc(route):
@@ -58,16 +62,14 @@ def _build_doc_from_func_doc(route):
5862
method = getattr(route.handler, method_name)
5963
if method.__doc__ is not None and "---" in method.__doc__:
6064
end_point_doc = method.__doc__.splitlines()
61-
out.update(
62-
_extract_swagger_docs(end_point_doc, method=method_name))
65+
out[method_name] = _extract_swagger_docs(end_point_doc)
6366

6467
else:
6568
try:
6669
end_point_doc = route.handler.__doc__.splitlines()
6770
except AttributeError:
6871
return {}
69-
out.update(_extract_swagger_docs(
70-
end_point_doc, method=route.method.lower()))
72+
out[route.method.lower()] = _extract_swagger_docs(end_point_doc)
7173
return out
7274

7375

@@ -150,7 +152,49 @@ def load_doc_from_yaml_file(doc_path: str) -> MutableMapping:
150152
return yaml.load(open(doc_path, "r").read())
151153

152154

155+
def add_swagger_validation(app, swagger_info: Mapping):
156+
for route in app.router.routes():
157+
method = route.method.lower()
158+
handler = route.handler
159+
url_info = route.get_info()
160+
url = url_info.get('path') or url_info.get('formatter')
161+
162+
if method != '*':
163+
swagger_endpoint_info_for_method = \
164+
swagger_info['paths'].get(url, {}).get(method)
165+
swagger_endpoint_info = \
166+
{method: swagger_endpoint_info_for_method} if \
167+
swagger_endpoint_info_for_method is not None else {}
168+
else:
169+
# all methods
170+
swagger_endpoint_info = swagger_info['paths'].get(url, {})
171+
for method, info in swagger_endpoint_info.items():
172+
logging.debug(
173+
'Added validation for method: {}. Path: {}'.
174+
format(method.upper(), url)
175+
)
176+
if issubclass(handler, web.View) and route.method == METH_ANY:
177+
# whole class validation
178+
should_be_validated = getattr(handler, 'validation', False)
179+
cls_method = getattr(handler, method, None)
180+
if cls_method is not None:
181+
if not should_be_validated:
182+
# method validation
183+
should_be_validated = \
184+
getattr(handler, 'validation', False)
185+
if should_be_validated:
186+
new_cls_method = \
187+
validate_decorator(swagger_info, info)(cls_method)
188+
setattr(handler, method, new_cls_method)
189+
else:
190+
should_be_validated = getattr(handler, 'validation', False)
191+
if should_be_validated:
192+
route._handler = \
193+
validate_decorator(swagger_info, info)(handler)
194+
195+
153196
__all__ = (
154197
"generate_doc_from_each_end_point",
155-
"load_doc_from_yaml_file"
198+
"load_doc_from_yaml_file",
199+
"add_swagger_validation",
156200
)

aiohttp_swagger/helpers/decorators.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,27 @@
1-
class swagger_path(object):
1+
from functools import partial
2+
from inspect import isfunction, isclass
3+
4+
__all__ = (
5+
'swagger_path',
6+
'swagger_validation',
7+
)
8+
9+
10+
class swagger_path:
11+
212
def __init__(self, swagger_file):
313
self.swagger_file = swagger_file
414

515
def __call__(self, f):
616
f.swagger_file = self.swagger_file
717
return f
18+
19+
20+
def swagger_validation(func=None, *, validation=True):
21+
22+
if func is None or not (isfunction(func) or isclass(func)):
23+
validation = func
24+
return partial(swagger_validation, validation=validation)
25+
26+
func.validation = validation
27+
return func

0 commit comments

Comments
 (0)