Skip to content

Commit 6f66f0f

Browse files
authored
Allow passing in a single string location to the locations kwarg (#402)
Allow locations kwarg for jwt_required() to be a string
1 parent 228822d commit 6f66f0f

File tree

5 files changed

+54
-20
lines changed

5 files changed

+54
-20
lines changed

docs/add_custom_data_claims.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
Storing Data in Access Tokens
2-
=============================
1+
Storing Additional Data in JWTs
2+
===============================
33
You may want to store additional information in the access token which you could
44
later access in the protected views. This can be done using the ``additional_claims``
55
argument with the :func:`~flask_jwt_extended.create_access_token` or

examples/additional_data_in_access_token.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def login():
3030
# In a protected view, get the claims you added to the jwt with the
3131
# get_jwt() method
3232
@app.route("/protected", methods=["GET"])
33-
@jwt_required
33+
@jwt_required()
3434
def protected():
3535
claims = get_jwt()
3636
return jsonify(foo=claims["foo"])

examples/loaders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def login():
2828

2929

3030
@app.route("/protected", methods=["GET"])
31-
@jwt_required
31+
@jwt_required()
3232
def protected():
3333
return jsonify(hello="world")
3434

flask_jwt_extended/view_decorators.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
5050
If ``True``, require a refresh JWT to be verified.
5151
5252
:param locations:
53-
A list of locations to look for the JWT in this request, for example:
54-
``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs
55-
will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION``
56-
configuration option.
53+
A location or list of locations to look for the JWT in this request, for
54+
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
55+
which indicates that JWTs will be looked for in the locations defined by the
56+
``JWT_TOKEN_LOCATION`` configuration option.
5757
"""
5858
if request.method in config.exempt_methods:
5959
return
@@ -103,10 +103,10 @@ def jwt_required(optional=False, fresh=False, refresh=False, locations=None):
103103
requires an access JWT to access this endpoint. Defaults to ``False``.
104104
105105
:param locations:
106-
A list of locations to look for the JWT in this request, for example:
107-
``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs
108-
will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION``
109-
configuration option.
106+
A location or list of locations to look for the JWT in this request, for
107+
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
108+
which indicates that JWTs will be looked for in the locations defined by the
109+
``JWT_TOKEN_LOCATION`` configuration option.
110110
"""
111111

112112
def wrapper(fn):
@@ -227,26 +227,28 @@ def _decode_jwt_from_json(refresh):
227227

228228

229229
def _decode_jwt_from_request(locations, fresh, refresh=False):
230-
# All the places we can get a JWT from in this request
231-
get_encoded_token_functions = []
230+
# Figure out what locations to look for the JWT in this request
231+
if isinstance(locations, str):
232+
locations = [locations]
232233

233-
# Get locations in the order specified by the decorator or JWT_TOKEN_LOCATION
234-
# configuration.
235234
if not locations:
236235
locations = config.token_location
237236

238-
# Add the functions in the order specified by locations.
237+
# Get the decode functions in the order specified by locations.
238+
get_encoded_token_functions = []
239239
for location in locations:
240240
if location == "cookies":
241241
get_encoded_token_functions.append(
242242
lambda: _decode_jwt_from_cookies(refresh)
243243
)
244-
if location == "query_string":
244+
elif location == "query_string":
245245
get_encoded_token_functions.append(_decode_jwt_from_query_string)
246-
if location == "headers":
246+
elif location == "headers":
247247
get_encoded_token_functions.append(_decode_jwt_from_headers)
248-
if location == "json":
248+
elif location == "json":
249249
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))
250+
else:
251+
raise RuntimeError(f"'{location}' is not a valid location")
250252

251253
# Try to find the token from one of these locations. It only needs to exist
252254
# in one place to be valid (not every location).

tests/test_view_decorators.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,38 @@ def test_jwt_optional(app, delta_func):
187187
assert response.get_json() == {"msg": "Token has expired"}
188188

189189

190+
def test_override_jwt_location(app):
191+
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
192+
193+
@app.route("/protected_other")
194+
@jwt_required(locations="headers")
195+
def protected_other():
196+
return jsonify(foo="bar")
197+
198+
@app.route("/protected_invalid")
199+
@jwt_required(locations="INVALID_LOCATION")
200+
def protected_invalid():
201+
return jsonify(foo="bar")
202+
203+
test_client = app.test_client()
204+
with app.test_request_context():
205+
access_token = create_access_token("username")
206+
207+
url = "/protected_other"
208+
response = test_client.get(url, headers=make_headers(access_token))
209+
assert response.get_json() == {"foo": "bar"}
210+
assert response.status_code == 200
211+
212+
url = "/protected"
213+
response = test_client.get(url, headers=make_headers(access_token))
214+
assert response.status_code == 401
215+
assert response.get_json() == {"msg": 'Missing cookie "access_token_cookie"'}
216+
217+
url = "/protected_invalid"
218+
response = test_client.get(url, headers=make_headers(access_token))
219+
assert response.status_code == 500
220+
221+
190222
def test_invalid_jwt(app):
191223
url = "/protected"
192224
jwtM = get_jwt_manager(app)

0 commit comments

Comments
 (0)