Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flask_cors/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ def wrapped_function(*args, **kwargs):

if options.get('automatic_options') and request.method == 'OPTIONS':
resp = current_app.make_default_options_response()
# if decorator does not have methods, then use the allowed methods
# from the view function. view function methods are preferred to
# decorator but not to app level options.
if 'methods' not in _options and resp.headers.get('allow'):
options['methods'] = resp.headers.get('allow')
else:
resp = make_response(f(*args, **kwargs))

Expand Down
5 changes: 5 additions & 0 deletions flask_cors/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def cors_after_request(resp):
normalized_path = unquote_plus(request.path)
for res_regex, res_options in resources:
if try_match(normalized_path, res_regex):
# if we reach here, the decorator has not evaluated the request
# so use the allowed methods from the view function because methods
# from view function are to be preferred to app level methods
if resp.headers is not None and resp.headers.get('allow'):
res_options['methods'] = resp.headers.get('allow')
LOG.debug("Request to '%s' matches CORS resource '%s'. Using options: %s",
request.path, get_regexp_pattern(res_regex), res_options)
set_cors_headers(resp, res_options)
Expand Down
46 changes: 44 additions & 2 deletions tests/decorator/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def defaults():
def test_get():
return 'Only allow POST'

@self.app.route('/test_methods_view', methods=['PATCH'])
@cross_origin()
def test_methods_view():
return 'Only allow PATCH'

@self.app.route('/test_methods_view_and_defined', methods=['POST', 'DELETE'])
@cross_origin(methods=['DELETE'])
def test_methods_view_and_defined():
return 'Only allow POST and DELETE'

def test_defaults(self):
''' Access-Control-Allow-Methods headers should only be returned
if the client makes an OPTIONS request.
Expand All @@ -38,8 +48,7 @@ def test_defaults(self):
self.assertFalse(ACL_METHODS in self.get('/defaults', origin='www.example.com').headers)
self.assertFalse(ACL_METHODS in self.head('/defaults', origin='www.example.com').headers)
res = self.preflight('/defaults', 'POST', origin='www.example.com')
for method in ALL_METHODS:
self.assertTrue(method in res.headers.get(ACL_METHODS))
self.assertIsNone(res.headers.get(ACL_METHODS))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is backward incompatible. Earlier if the cross_origin decorator had no methods specified all the methods from the default or app level configuration would be used but now the methods will be restricted to what the view supports.


def test_methods_defined(self):
''' If the methods parameter is defined, it should override the default
Expand All @@ -57,5 +66,38 @@ def test_methods_defined(self):
res = self.get('/test_methods_defined', origin='www.example.com')
self.assertFalse(ACL_METHODS in res.headers)

def test_methods_view(self):
''' If the methods parameter is defined in view function, it should override the default
methods defined by the user.
'''
self.assertFalse(ACL_METHODS in self.get('/test_methods_view').headers)
self.assertFalse(ACL_METHODS in self.head('/test_methods_view').headers)

res = self.preflight('/test_methods_view', 'PATCH', origin='www.example.com')
self.assertTrue('PATCH' in res.headers.get(ACL_METHODS))

res = self.preflight('/test_methods_view', 'POST', origin='www.example.com')
self.assertFalse(ACL_METHODS in res.headers)

res = self.get('/test_methods_view', origin='www.example.com')
self.assertFalse(ACL_METHODS in res.headers)

def test_methods_view_and_defined(self):
''' If the methods parameter is defined in cross_origin decorator and view
function, the decorator methods should be used.
'''
self.assertFalse(ACL_METHODS in self.get('/test_methods_view_and_defined').headers)
self.assertFalse(ACL_METHODS in self.head('/test_methods_view_and_defined').headers)

res = self.preflight('/test_methods_view_and_defined', 'DELETE', origin='www.example.com')
self.assertTrue('DELETE' in res.headers.get(ACL_METHODS))
self.assertTrue('POST' not in res.headers.get(ACL_METHODS))

res = self.preflight('/test_methods_view_and_defined', 'POST', origin='www.example.com')
self.assertFalse(ACL_METHODS in res.headers)

res = self.get('/test_methods_view_and_defined', origin='www.example.com')
self.assertFalse(ACL_METHODS in res.headers)

if __name__ == "__main__":
unittest.main()
23 changes: 23 additions & 0 deletions tests/extension/test_app_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,5 +378,28 @@ def index():
self.assertEqual(resp.status_code, 200)


class AppExtensionPreflight(FlaskCorsTestCase):
def test_preflight(self):
''' Ensure that view function methods override app level defaults '''
self.app = Flask(__name__)
CORS(self.app)

@self.app.route('/')
def index():
return 'Welcome'

@self.app.route('/test', methods=['POST'])
def index2():
return 'Welcome 2'

res = self.preflight('/', 'GET', origin='www.example.com')
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is also backward incompatible in a way that instead of default or level configuration, methods from view function are used.

self.assertTrue('POST' not in res.headers.get(ACL_METHODS))
self.assertTrue('GET' in res.headers.get(ACL_METHODS))

res = self.preflight('/test', 'POST', origin='www.example.com')
self.assertTrue('POST' in res.headers.get(ACL_METHODS))
self.assertTrue('GET' not in res.headers.get(ACL_METHODS))


if __name__ == "__main__":
unittest.main()