1
+ import logging
1
2
from typing import (
2
3
MutableMapping ,
3
4
Mapping ,
13
14
from aiohttp import web
14
15
from aiohttp .hdrs import METH_ANY , METH_ALL
15
16
from jinja2 import Template
16
-
17
17
try :
18
18
import ujson as json
19
19
except ImportError : # pragma: no cover
20
20
import json
21
21
22
+ from .validation import validate_decorator
23
+
22
24
23
25
SWAGGER_TEMPLATE = abspath (join (dirname (__file__ ), ".." , "templates" ))
24
26
25
27
26
- def _extract_swagger_docs (end_point_doc , method = "get" ):
27
- # Find Swagger start point in doc
28
+ def _extract_swagger_docs (end_point_doc : str ) -> Mapping :
29
+ """
30
+ Find Swagger start point in doc.
31
+ """
28
32
end_point_swagger_start = 0
29
33
for i , doc_line in enumerate (end_point_doc ):
30
34
if "---" in doc_line :
@@ -42,7 +46,7 @@ def _extract_swagger_docs(end_point_doc, method="get"):
42
46
"from docstring ⚠" ,
43
47
"tags" : ["Invalid Swagger" ]
44
48
}
45
- return { method : end_point_swagger_doc }
49
+ return end_point_swagger_doc
46
50
47
51
48
52
def _build_doc_from_func_doc (route ):
@@ -58,16 +62,14 @@ def _build_doc_from_func_doc(route):
58
62
method = getattr (route .handler , method_name )
59
63
if method .__doc__ is not None and "---" in method .__doc__ :
60
64
end_point_doc = method .__doc__ .splitlines ()
61
- out .update (
62
- _extract_swagger_docs (end_point_doc , method = method_name ))
65
+ out [method_name ] = _extract_swagger_docs (end_point_doc )
63
66
64
67
else :
65
68
try :
66
69
end_point_doc = route .handler .__doc__ .splitlines ()
67
70
except AttributeError :
68
71
return {}
69
- out .update (_extract_swagger_docs (
70
- end_point_doc , method = route .method .lower ()))
72
+ out [route .method .lower ()] = _extract_swagger_docs (end_point_doc )
71
73
return out
72
74
73
75
@@ -150,7 +152,49 @@ def load_doc_from_yaml_file(doc_path: str) -> MutableMapping:
150
152
return yaml .load (open (doc_path , "r" ).read ())
151
153
152
154
155
+ def add_swagger_validation (app , swagger_info : Mapping ):
156
+ for route in app .router .routes ():
157
+ method = route .method .lower ()
158
+ handler = route .handler
159
+ url_info = route .get_info ()
160
+ url = url_info .get ('path' ) or url_info .get ('formatter' )
161
+
162
+ if method != '*' :
163
+ swagger_endpoint_info_for_method = \
164
+ swagger_info ['paths' ].get (url , {}).get (method )
165
+ swagger_endpoint_info = \
166
+ {method : swagger_endpoint_info_for_method } if \
167
+ swagger_endpoint_info_for_method is not None else {}
168
+ else :
169
+ # all methods
170
+ swagger_endpoint_info = swagger_info ['paths' ].get (url , {})
171
+ for method , info in swagger_endpoint_info .items ():
172
+ logging .debug (
173
+ 'Added validation for method: {}. Path: {}' .
174
+ format (method .upper (), url )
175
+ )
176
+ if issubclass (handler , web .View ) and route .method == METH_ANY :
177
+ # whole class validation
178
+ should_be_validated = getattr (handler , 'validation' , False )
179
+ cls_method = getattr (handler , method , None )
180
+ if cls_method is not None :
181
+ if not should_be_validated :
182
+ # method validation
183
+ should_be_validated = \
184
+ getattr (handler , 'validation' , False )
185
+ if should_be_validated :
186
+ new_cls_method = \
187
+ validate_decorator (swagger_info , info )(cls_method )
188
+ setattr (handler , method , new_cls_method )
189
+ else :
190
+ should_be_validated = getattr (handler , 'validation' , False )
191
+ if should_be_validated :
192
+ route ._handler = \
193
+ validate_decorator (swagger_info , info )(handler )
194
+
195
+
153
196
__all__ = (
154
197
"generate_doc_from_each_end_point" ,
155
- "load_doc_from_yaml_file"
198
+ "load_doc_from_yaml_file" ,
199
+ "add_swagger_validation" ,
156
200
)
0 commit comments