Skip to content

Commit aaca2cf

Browse files
committed
Adding ssl_verify_flags_config argument for ssl connection configuration
1 parent 8403ddc commit aaca2cf

File tree

9 files changed

+308
-6
lines changed

9 files changed

+308
-6
lines changed

redis/asyncio/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,11 @@
8181
)
8282

8383
if TYPE_CHECKING and SSL_AVAILABLE:
84-
from ssl import TLSVersion, VerifyMode
84+
from ssl import TLSVersion, VerifyFlags, VerifyMode
8585
else:
8686
TLSVersion = None
8787
VerifyMode = None
88+
VerifyFlags = None
8889

8990
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
9091
_KeyT = TypeVar("_KeyT", bound=KeyT)
@@ -238,6 +239,7 @@ def __init__(
238239
ssl_keyfile: Optional[str] = None,
239240
ssl_certfile: Optional[str] = None,
240241
ssl_cert_reqs: Union[str, VerifyMode] = "required",
242+
ssl_verify_flags_config: Optional[List[Tuple[VerifyFlags, bool]]] = None,
241243
ssl_ca_certs: Optional[str] = None,
242244
ssl_ca_data: Optional[str] = None,
243245
ssl_check_hostname: bool = True,
@@ -347,6 +349,7 @@ def __init__(
347349
"ssl_keyfile": ssl_keyfile,
348350
"ssl_certfile": ssl_certfile,
349351
"ssl_cert_reqs": ssl_cert_reqs,
352+
"ssl_verify_flags_config": ssl_verify_flags_config,
350353
"ssl_ca_certs": ssl_ca_certs,
351354
"ssl_ca_data": ssl_ca_data,
352355
"ssl_check_hostname": ssl_check_hostname,

redis/asyncio/cluster.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,11 @@
8686
)
8787

8888
if SSL_AVAILABLE:
89-
from ssl import TLSVersion, VerifyMode
89+
from ssl import TLSVersion, VerifyFlags, VerifyMode
9090
else:
9191
TLSVersion = None
9292
VerifyMode = None
93+
VerifyFlags = None
9394

9495
TargetNodesT = TypeVar(
9596
"TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
@@ -299,6 +300,7 @@ def __init__(
299300
ssl_ca_certs: Optional[str] = None,
300301
ssl_ca_data: Optional[str] = None,
301302
ssl_cert_reqs: Union[str, VerifyMode] = "required",
303+
ssl_verify_flags_config: Optional[List[Tuple[VerifyFlags, bool]]] = None,
302304
ssl_certfile: Optional[str] = None,
303305
ssl_check_hostname: bool = True,
304306
ssl_keyfile: Optional[str] = None,
@@ -358,6 +360,7 @@ def __init__(
358360
"ssl_ca_certs": ssl_ca_certs,
359361
"ssl_ca_data": ssl_ca_data,
360362
"ssl_cert_reqs": ssl_cert_reqs,
363+
"ssl_verify_flags_config": ssl_verify_flags_config,
361364
"ssl_certfile": ssl_certfile,
362365
"ssl_check_hostname": ssl_check_hostname,
363366
"ssl_keyfile": ssl_keyfile,

redis/asyncio/connection.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import ast
12
import asyncio
23
import copy
34
import enum
45
import inspect
6+
import re
57
import socket
68
import sys
79
import warnings
@@ -30,11 +32,12 @@
3032

3133
if SSL_AVAILABLE:
3234
import ssl
33-
from ssl import SSLContext, TLSVersion
35+
from ssl import SSLContext, TLSVersion, VerifyFlags
3436
else:
3537
ssl = None
3638
TLSVersion = None
3739
SSLContext = None
40+
VerifyFlags = None
3841

3942
from ..auth.token import TokenInterface
4043
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
@@ -793,6 +796,7 @@ def __init__(
793796
ssl_keyfile: Optional[str] = None,
794797
ssl_certfile: Optional[str] = None,
795798
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
799+
ssl_verify_flags_config: Optional[List[Tuple["ssl.VerifyFlags", bool]]] = None,
796800
ssl_ca_certs: Optional[str] = None,
797801
ssl_ca_data: Optional[str] = None,
798802
ssl_check_hostname: bool = True,
@@ -807,6 +811,7 @@ def __init__(
807811
keyfile=ssl_keyfile,
808812
certfile=ssl_certfile,
809813
cert_reqs=ssl_cert_reqs,
814+
verify_flags_config=ssl_verify_flags_config,
810815
ca_certs=ssl_ca_certs,
811816
ca_data=ssl_ca_data,
812817
check_hostname=ssl_check_hostname,
@@ -832,6 +837,10 @@ def certfile(self):
832837
def cert_reqs(self):
833838
return self.ssl_context.cert_reqs
834839

840+
@property
841+
def verify_flags_config(self):
842+
return self.ssl_context.verify_flags_config
843+
835844
@property
836845
def ca_certs(self):
837846
return self.ssl_context.ca_certs
@@ -854,6 +863,7 @@ class RedisSSLContext:
854863
"keyfile",
855864
"certfile",
856865
"cert_reqs",
866+
"verify_flags_config",
857867
"ca_certs",
858868
"ca_data",
859869
"context",
@@ -867,6 +877,7 @@ def __init__(
867877
keyfile: Optional[str] = None,
868878
certfile: Optional[str] = None,
869879
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
880+
verify_flags_config: Optional[List[Tuple[ssl.VerifyFlags, bool]]] = None,
870881
ca_certs: Optional[str] = None,
871882
ca_data: Optional[str] = None,
872883
check_hostname: bool = False,
@@ -892,6 +903,7 @@ def __init__(
892903
)
893904
cert_reqs = CERT_REQS[cert_reqs]
894905
self.cert_reqs = cert_reqs
906+
self.verify_flags_config = verify_flags_config
895907
self.ca_certs = ca_certs
896908
self.ca_data = ca_data
897909
self.check_hostname = (
@@ -906,6 +918,12 @@ def get(self) -> SSLContext:
906918
context = ssl.create_default_context()
907919
context.check_hostname = self.check_hostname
908920
context.verify_mode = self.cert_reqs
921+
if self.verify_flags_config:
922+
for flag, enabled in self.verify_flags_config:
923+
if enabled:
924+
context.options |= flag
925+
else:
926+
context.options &= ~flag
909927
if self.certfile and self.keyfile:
910928
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
911929
if self.ca_certs or self.ca_data:
@@ -1021,6 +1039,34 @@ def parse_url(url: str) -> ConnectKwargs:
10211039

10221040
if parsed.scheme == "rediss":
10231041
kwargs["connection_class"] = SSLConnection
1042+
1043+
if "ssl_verify_flags_config" in kwargs:
1044+
# flags are passed in as a string representation of a list,
1045+
# e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)]
1046+
# To parse it sucessfully, we need transform the flags to strings with quotes.
1047+
verify_flags_config_str = kwargs.pop("ssl_verify_flags_config")
1048+
# First wrap any VERIFY_* name in quotes
1049+
verify_flags_config_str = re.sub(
1050+
r"\b(VERIFY_[A-Z0-9_]+)\b", r'"\1"', verify_flags_config_str
1051+
)
1052+
1053+
# transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag,
1054+
# and the second is a boolean that indicates if the flad should be enabled or disabled
1055+
verify_flags_config = ast.literal_eval(verify_flags_config_str)
1056+
1057+
verify_flags_config_config_parsed = []
1058+
for flag, enabled in verify_flags_config:
1059+
if not hasattr(VerifyFlags, flag):
1060+
raise ValueError(f"Invalid verify flag: {flag}")
1061+
if not isinstance(enabled, bool):
1062+
raise ValueError(
1063+
f"Invalid verify flag enabled/disabled value: {enabled}"
1064+
)
1065+
verify_flags_config_config_parsed.append(
1066+
(getattr(VerifyFlags, flag), enabled)
1067+
)
1068+
1069+
kwargs["ssl_verify_flags_config"] = verify_flags_config_config_parsed
10241070
else:
10251071
valid_schemes = "redis://, rediss://, unix://"
10261072
raise ValueError(

redis/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Mapping,
1313
Optional,
1414
Set,
15+
Tuple,
1516
Type,
1617
Union,
1718
)
@@ -224,6 +225,7 @@ def __init__(
224225
ssl_keyfile: Optional[str] = None,
225226
ssl_certfile: Optional[str] = None,
226227
ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
228+
ssl_verify_flags_config: Optional[List[Tuple["ssl.VerifyFlags", bool]]] = None,
227229
ssl_ca_certs: Optional[str] = None,
228230
ssl_ca_path: Optional[str] = None,
229231
ssl_ca_data: Optional[str] = None,
@@ -330,6 +332,7 @@ def __init__(
330332
"ssl_keyfile": ssl_keyfile,
331333
"ssl_certfile": ssl_certfile,
332334
"ssl_cert_reqs": ssl_cert_reqs,
335+
"ssl_verify_flags_config": ssl_verify_flags_config,
333336
"ssl_ca_certs": ssl_ca_certs,
334337
"ssl_ca_data": ssl_ca_data,
335338
"ssl_check_hostname": ssl_check_hostname,

redis/cluster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def parse_cluster_myshardid(resp, **options):
184184
"ssl_ca_data",
185185
"ssl_certfile",
186186
"ssl_cert_reqs",
187+
"ssl_verify_flags_config",
187188
"ssl_keyfile",
188189
"ssl_password",
189190
"ssl_check_hostname",

redis/connection.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import ast
12
import copy
23
import os
4+
import re
35
import socket
46
import sys
57
import threading
@@ -16,6 +18,7 @@
1618
List,
1719
Literal,
1820
Optional,
21+
Tuple,
1922
Type,
2023
TypeVar,
2124
Union,
@@ -68,8 +71,10 @@
6871

6972
if SSL_AVAILABLE:
7073
import ssl
74+
from ssl import VerifyFlags
7175
else:
7276
ssl = None
77+
VerifyFlags = None
7378

7479
if HIREDIS_AVAILABLE:
7580
import hiredis
@@ -1358,6 +1363,7 @@ def __init__(
13581363
ssl_keyfile=None,
13591364
ssl_certfile=None,
13601365
ssl_cert_reqs="required",
1366+
ssl_verify_flags_config: Optional[List[Tuple["VerifyFlags", bool]]] = None,
13611367
ssl_ca_certs=None,
13621368
ssl_ca_data=None,
13631369
ssl_check_hostname=True,
@@ -1376,7 +1382,19 @@ def __init__(
13761382
Args:
13771383
ssl_keyfile: Path to an ssl private key. Defaults to None.
13781384
ssl_certfile: Path to an ssl certificate. Defaults to None.
1379-
ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
1385+
ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
1386+
or an ssl.VerifyMode. Defaults to "required".
1387+
ssl_verify_flags_config: A list with flags configuration to be set on the SSLContext. Defaults to None.
1388+
Valid format is as follows:
1389+
[
1390+
(config_flag, enabled/disabled),
1391+
...
1392+
]
1393+
Example:
1394+
[
1395+
(ssl.VERIFY_X509_STRICT, False), # disable strict
1396+
(ssl.VERIFY_X509_PARTIAL_CHAIN, True), # ensure partial chain is enabled
1397+
]
13801398
ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
13811399
ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
13821400
ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
@@ -1412,6 +1430,7 @@ def __init__(
14121430
)
14131431
ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
14141432
self.cert_reqs = ssl_cert_reqs
1433+
self.ssl_verify_flags_config = ssl_verify_flags_config
14151434
self.ca_certs = ssl_ca_certs
14161435
self.ca_data = ssl_ca_data
14171436
self.ca_path = ssl_ca_path
@@ -1451,6 +1470,12 @@ def _wrap_socket_with_ssl(self, sock):
14511470
context = ssl.create_default_context()
14521471
context.check_hostname = self.check_hostname
14531472
context.verify_mode = self.cert_reqs
1473+
if self.ssl_verify_flags_config:
1474+
for flag, enabled in self.ssl_verify_flags_config:
1475+
if enabled:
1476+
context.options |= flag
1477+
else:
1478+
context.options &= ~flag
14541479
if self.certfile or self.keyfile:
14551480
context.load_cert_chain(
14561481
certfile=self.certfile,
@@ -1632,6 +1657,34 @@ def parse_url(url):
16321657
if url.scheme == "rediss":
16331658
kwargs["connection_class"] = SSLConnection
16341659

1660+
if "ssl_verify_flags_config" in kwargs:
1661+
# flags are passed in as a string representation of a list,
1662+
# e.g. [(VERIFY_X509_STRICT, False), (VERIFY_X509_PARTIAL_CHAIN, True)]
1663+
# To parse it sucessfully, we need transform the flags to strings with quotes.
1664+
verify_flags_config_str = kwargs.pop("ssl_verify_flags_config")
1665+
# First wrap any VERIFY_* name in quotes
1666+
verify_flags_config_str = re.sub(
1667+
r"\b(VERIFY_[A-Z0-9_]+)\b", r'"\1"', verify_flags_config_str
1668+
)
1669+
1670+
# transform the string to a list of tuples - the first element of each tuple is a string containing the name of the flag,
1671+
# and the second is a boolean that indicates if the flad should be enabled or disabled
1672+
verify_flags_config = ast.literal_eval(verify_flags_config_str)
1673+
1674+
ssl_verify_flags_config_parsed = []
1675+
for flag, enabled in verify_flags_config:
1676+
if not hasattr(VerifyFlags, flag):
1677+
raise ValueError(f"Invalid ssl verify flag: {flag}")
1678+
if not isinstance(enabled, bool):
1679+
raise ValueError(
1680+
f"Invalid ssl verify flag enabled/disabled value: {enabled}"
1681+
)
1682+
ssl_verify_flags_config_parsed.append(
1683+
(getattr(VerifyFlags, flag), enabled)
1684+
)
1685+
1686+
kwargs["ssl_verify_flags_config"] = ssl_verify_flags_config_parsed
1687+
16351688
return kwargs
16361689

16371690

0 commit comments

Comments
 (0)