Skip to content

Commit 228822d

Browse files
authored
Remove hardcoded assumption that the only JWT type supported other than "refresh" is "access" (#401)
Cloudflare Teams JWT auth for example, sets a token with a value of "app".
1 parent 715f9d5 commit 228822d

File tree

5 files changed

+35
-37
lines changed

5 files changed

+35
-37
lines changed

flask_jwt_extended/internal_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@ def user_lookup(*args, **kwargs):
2525
return jwt_manager._user_lookup_callback(*args, **kwargs)
2626

2727

28-
def verify_token_type(decoded_token, expected_type):
29-
if decoded_token["type"] != expected_type:
30-
raise WrongTokenError("Only {} tokens are allowed".format(expected_type))
28+
def verify_token_type(decoded_token, refresh):
29+
if not refresh and decoded_token["type"] == "refresh":
30+
raise WrongTokenError("Only non-refresh tokens are allowed")
31+
elif refresh and decoded_token["type"] != "refresh":
32+
raise WrongTokenError("Only refresh tokens are allowed")
3133

3234

33-
def verify_token_not_blocklisted(jwt_header, jwt_data, request_type):
35+
def verify_token_not_blocklisted(jwt_header, jwt_data):
3436
jwt_manager = get_jwt_manager()
3537
if jwt_manager._token_in_blocklist_callback(jwt_header, jwt_data):
3638
raise RevokedTokenError(jwt_header, jwt_data)

flask_jwt_extended/tokens.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,6 @@ def _decode_jwt(
9898
if "type" not in decoded_token:
9999
decoded_token["type"] = "access"
100100

101-
if decoded_token["type"] not in ("access", "refresh"):
102-
raise JWTDecodeError("Invalid token type: {}".format(decoded_token["type"]))
103-
104101
if "fresh" not in decoded_token:
105102
decoded_token["fresh"] = False
106103

flask_jwt_extended/view_decorators.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
4747
Defaults to ``False``.
4848
4949
:param refresh:
50-
If ``True``, require a refresh JWT to be verified. If ``False`` require an access
51-
JWT to be verified. Defaults to ``False``.
50+
If ``True``, require a refresh JWT to be verified.
5251
5352
:param locations:
5453
A list of locations to look for the JWT in this request, for example:
@@ -61,9 +60,11 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
6160

6261
try:
6362
if refresh:
64-
jwt_data, jwt_header = _decode_jwt_from_request("refresh", locations, fresh)
63+
jwt_data, jwt_header = _decode_jwt_from_request(
64+
locations, fresh, refresh=True
65+
)
6566
else:
66-
jwt_data, jwt_header = _decode_jwt_from_request("access", locations, fresh)
67+
jwt_data, jwt_header = _decode_jwt_from_request(locations, fresh)
6768
except (NoAuthorizationError, InvalidHeaderError):
6869
if not optional:
6970
raise
@@ -170,15 +171,15 @@ def _decode_jwt_from_headers():
170171
return encoded_token, None
171172

172173

173-
def _decode_jwt_from_cookies(token_type):
174-
if token_type == "access":
175-
cookie_key = config.access_cookie_name
176-
csrf_header_key = config.access_csrf_header_name
177-
csrf_field_key = config.access_csrf_field_name
178-
else:
174+
def _decode_jwt_from_cookies(refresh):
175+
if refresh:
179176
cookie_key = config.refresh_cookie_name
180177
csrf_header_key = config.refresh_csrf_header_name
181178
csrf_field_key = config.refresh_csrf_field_name
179+
else:
180+
cookie_key = config.access_cookie_name
181+
csrf_header_key = config.access_csrf_header_name
182+
csrf_field_key = config.access_csrf_field_name
182183

183184
encoded_token = request.cookies.get(cookie_key)
184185
if not encoded_token:
@@ -205,15 +206,15 @@ def _decode_jwt_from_query_string():
205206
return encoded_token, None
206207

207208

208-
def _decode_jwt_from_json(token_type):
209+
def _decode_jwt_from_json(refresh):
209210
content_type = request.content_type or ""
210211
if not content_type.startswith("application/json"):
211212
raise NoAuthorizationError("Invalid content-type. Must be application/json.")
212213

213-
if token_type == "access":
214-
token_key = config.json_key
215-
else:
214+
if refresh:
216215
token_key = config.refresh_json_key
216+
else:
217+
token_key = config.json_key
217218

218219
try:
219220
encoded_token = request.json.get(token_key, None)
@@ -225,7 +226,7 @@ def _decode_jwt_from_json(token_type):
225226
return encoded_token, None
226227

227228

228-
def _decode_jwt_from_request(token_type, locations, fresh):
229+
def _decode_jwt_from_request(locations, fresh, refresh=False):
229230
# All the places we can get a JWT from in this request
230231
get_encoded_token_functions = []
231232

@@ -238,16 +239,14 @@ def _decode_jwt_from_request(token_type, locations, fresh):
238239
for location in locations:
239240
if location == "cookies":
240241
get_encoded_token_functions.append(
241-
lambda: _decode_jwt_from_cookies(token_type)
242+
lambda: _decode_jwt_from_cookies(refresh)
242243
)
243244
if location == "query_string":
244245
get_encoded_token_functions.append(_decode_jwt_from_query_string)
245246
if location == "headers":
246247
get_encoded_token_functions.append(_decode_jwt_from_headers)
247248
if location == "json":
248-
get_encoded_token_functions.append(
249-
lambda: _decode_jwt_from_json(token_type)
250-
)
249+
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))
251250

252251
# Try to find the token from one of these locations. It only needs to exist
253252
# in one place to be valid (not every location).
@@ -277,10 +276,10 @@ def _decode_jwt_from_request(token_type, locations, fresh):
277276
raise NoAuthorizationError(errors[0])
278277

279278
# Additional verifications provided by this extension
280-
verify_token_type(decoded_token, expected_type=token_type)
279+
verify_token_type(decoded_token, refresh)
281280
if fresh:
282281
_verify_token_is_fresh(jwt_header, decoded_token)
283-
verify_token_not_blocklisted(jwt_header, decoded_token, token_type)
282+
verify_token_not_blocklisted(jwt_header, decoded_token)
284283
custom_verification_for_token(jwt_header, decoded_token)
285284

286285
return decoded_token, jwt_header

tests/test_decode_tokens.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ def test_default_decode_token_values(app, default_access_token):
7070
assert decoded["fresh"] is False
7171

7272

73-
def test_bad_token_type(app, default_access_token):
74-
default_access_token["type"] = "banana"
75-
bad_type_token = encode_token(app, default_access_token)
73+
def test_supports_decoding_other_token_types(app, default_access_token):
74+
default_access_token["type"] = "app"
75+
other_token = encode_token(app, default_access_token)
7676

77-
with pytest.raises(JWTDecodeError):
78-
with app.test_request_context():
79-
decode_token(bad_type_token)
77+
with app.test_request_context():
78+
decoded = decode_token(other_token)
79+
assert decoded["type"] == "app"
8080

8181

8282
def test_encode_decode_audience(app):

tests/test_view_decorators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_jwt_required(app):
7272
# Test refresh token access to jwt_required
7373
response = test_client.get(url, headers=make_headers(refresh_token))
7474
assert response.status_code == 422
75-
assert response.get_json() == {"msg": "Only access tokens are allowed"}
75+
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}
7676

7777

7878
def test_fresh_jwt_required(app):
@@ -113,7 +113,7 @@ def test_fresh_jwt_required(app):
113113

114114
response = test_client.get(url, headers=make_headers(refresh_token))
115115
assert response.status_code == 422
116-
assert response.get_json() == {"msg": "Only access tokens are allowed"}
116+
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}
117117

118118
# Test with custom response
119119
@jwtM.needs_fresh_token_loader
@@ -176,7 +176,7 @@ def test_jwt_optional(app, delta_func):
176176

177177
response = test_client.get(url, headers=make_headers(refresh_token))
178178
assert response.status_code == 422
179-
assert response.get_json() == {"msg": "Only access tokens are allowed"}
179+
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}
180180

181181
response = test_client.get(url, headers=None)
182182
assert response.status_code == 200

0 commit comments

Comments
 (0)