Skip to content
2 changes: 2 additions & 0 deletions src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,8 @@ async def connect(self, **kwargs) -> None:
proxy_user=self.proxy_user,
proxy_password=self.proxy_password,
snowflake_ocsp_mode=self._ocsp_mode(),
ocsp_root_certs_dict_lock_timeout=self._ocsp_root_certs_dict_lock_timeout,
ocsp_response_cache_file_name=self._ocsp_response_cache_filename,
trust_env=True, # Required for proxy support via environment variables
)
self._session_manager = SessionManagerFactory.get_manager(self._http_config)
Expand Down
8 changes: 7 additions & 1 deletion src/snowflake/connector/aio/_ocsp_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
)
from snowflake.connector.errors import RevocationCheckError
from snowflake.connector.network import PYTHON_CONNECTOR_USER_AGENT
from snowflake.connector.ocsp_snowflake import OCSPCache, OCSPResponseValidationResult
from snowflake.connector.ocsp_snowflake import (
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT,
OCSPCache,
OCSPResponseValidationResult,
)
from snowflake.connector.ocsp_snowflake import OCSPServer as OCSPServerSync
from snowflake.connector.ocsp_snowflake import OCSPTelemetryData
from snowflake.connector.ocsp_snowflake import SnowflakeOCSP as SnowflakeOCSPSync
Expand Down Expand Up @@ -143,6 +147,7 @@ def __init__(
use_ocsp_cache_server=None,
use_post_method: bool = True,
use_fail_open: bool = True,
root_certs_dict_lock_timeout: int = OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT,
**kwargs,
) -> None:
self.test_mode = os.getenv("SF_OCSP_TEST_MODE", None)
Expand All @@ -151,6 +156,7 @@ def __init__(
logger.debug("WARNING - DRIVER CONFIGURED IN TEST MODE")

self._use_post_method = use_post_method
self._root_certs_dict_lock_timeout = root_certs_dict_lock_timeout
self.OCSP_CACHE_SERVER = OCSPServer(
top_level_domain=extract_top_level_domain_from_hostname(
kwargs.pop("hostname", None)
Expand Down
28 changes: 24 additions & 4 deletions src/snowflake/connector/aio/_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from aiohttp.typedefs import StrOrURL

from .. import OperationalError
from ..constants import OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT
from ..errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED
from ..ssl_wrap_socket import FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME
from ._ocsp_asn1crypto import SnowflakeOCSPAsn1Crypto

if TYPE_CHECKING:
Expand Down Expand Up @@ -45,9 +45,13 @@ def __init__(
*args,
snowflake_ocsp_mode: OCSPMode = OCSPMode.FAIL_OPEN,
session_manager: SessionManager | None = None,
ocsp_root_certs_dict_lock_timeout: int = OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT,
ocsp_response_cache_file_name: str | None = None,
**kwargs,
):
self._snowflake_ocsp_mode = snowflake_ocsp_mode
self._ocsp_root_certs_dict_lock_timeout = ocsp_root_certs_dict_lock_timeout
self._ocsp_response_cache_file_name = ocsp_response_cache_file_name
if session_manager is None:
logger.warning(
"SessionManager instance was not passed to SSLConnector - OCSP will use default settings which may be distinct from the customer's specific one. Code should always pass such instance - verify why it isn't true in the current context"
Expand Down Expand Up @@ -99,9 +103,10 @@ async def validate_ocsp(
):

v = await SnowflakeOCSPAsn1Crypto(
ocsp_response_cache_uri=FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME,
ocsp_response_cache_uri=self._ocsp_response_cache_file_name,
use_fail_open=self._snowflake_ocsp_mode == OCSPMode.FAIL_OPEN,
hostname=hostname,
root_certs_dict_lock_timeout=self._ocsp_root_certs_dict_lock_timeout,
).validate(hostname, protocol, session_manager=session_manager)
if not v:
raise OperationalError(
Expand Down Expand Up @@ -140,6 +145,10 @@ class AioHttpConfig(BaseHttpConfig):
connector_factory: Callable[..., aiohttp.BaseConnector] = field(
default_factory=SnowflakeSSLConnectorFactory
)
ocsp_root_certs_dict_lock_timeout: int = (
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT
)
ocsp_response_cache_file_name: str | None = None

trust_env: bool = True
"""Trust environment variables for proxy configuration (HTTP_PROXY, HTTPS_PROXY, NO_PROXY).
Expand All @@ -153,7 +162,13 @@ def get_connector(
) -> aiohttp.BaseConnector:
# We pass here only chosen attributes as kwargs to make the arguments received by the factory as compliant with the BaseConnector constructor interface as possible.
# We could consider passing the whole HttpConfig as kwarg to the factory if necessary in the future.
attributes_for_connector_factory = frozenset({"snowflake_ocsp_mode"})
attributes_for_connector_factory = frozenset(
{
"snowflake_ocsp_mode",
"ocsp_root_certs_dict_lock_timeout",
"ocsp_response_cache_file_name",
}
)

self_kwargs_for_connector_factory = {
attr_name: getattr(self, attr_name)
Expand Down Expand Up @@ -223,8 +238,13 @@ async def get(
use_pooling: bool | None = None,
**kwargs,
) -> aiohttp.ClientResponse:
async with self.use_session(url, use_pooling) as session:
if isinstance(timeout, tuple):
connect, total = timeout
timeout_obj = aiohttp.ClientTimeout(total=total, connect=connect)
else:
timeout_obj = aiohttp.ClientTimeout(total=timeout) if timeout else None

async with self.use_session(url, use_pooling) as session:
return await session.get(
url, headers=headers, timeout=timeout_obj, **kwargs
)
Expand Down
8 changes: 8 additions & 0 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
_DOMAIN_NAME_MAP,
_OAUTH_DEFAULT_SCOPE,
ENV_VAR_PARTNER,
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT,
PARAMETER_AUTOCOMMIT,
PARAMETER_CLIENT_PREFETCH_THREADS,
PARAMETER_CLIENT_REQUEST_MFA_TOKEN,
Expand Down Expand Up @@ -242,6 +243,10 @@ def _get_private_bytes_from_file(
"internal_application_version": (CLIENT_VERSION, (type(None), str)),
"disable_ocsp_checks": (False, bool),
"ocsp_fail_open": (True, bool), # fail open on ocsp issues, default true
"ocsp_root_certs_dict_lock_timeout": (
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT, # no timeout
int,
),
"inject_client_pause": (0, int), # snowflake internal
"session_parameters": (None, (type(None), dict)), # snowflake session parameters
"autocommit": (None, (type(None), bool)), # snowflake
Expand Down Expand Up @@ -443,6 +448,7 @@ class SnowflakeConnection:
validates the TLS certificate but doesn't check revocation status with OCSP provider.
ocsp_fail_open: Whether or not the connection is in fail open mode. Fail open mode decides if TLS certificates
continue to be validated. Revoked certificates are blocked. Any other exceptions are disregarded.
ocsp_root_certs_dict_lock_timeout: Timeout for the OCSP root certs dict lock in seconds. Default value is -1, which means no timeout.
session_id: The session ID of the connection.
user: The user name used in the connection.
host: The host name the connection attempts to connect to.
Expand Down Expand Up @@ -1545,6 +1551,8 @@ def __config(self, **kwargs):
WORKLOAD_IDENTITY_AUTHENTICATOR,
PROGRAMMATIC_ACCESS_TOKEN,
PAT_WITH_EXTERNAL_SESSION,
OAUTH_AUTHORIZATION_CODE,
OAUTH_CLIENT_CREDENTIALS,
}

if not (self._master_token and self._session_token):
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/connector/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ class FileHeader(NamedTuple):

HTTP_HEADER_VALUE_OCTET_STREAM = "application/octet-stream"

# OCSP
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT: int = -1


@unique
class OCSPMode(Enum):
Expand Down
7 changes: 7 additions & 0 deletions src/snowflake/connector/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
HTTP_HEADER_CONTENT_TYPE,
HTTP_HEADER_SERVICE_NAME,
HTTP_HEADER_USER_AGENT,
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT,
)
from .description import (
CLIENT_NAME,
Expand Down Expand Up @@ -337,6 +338,12 @@ def __init__(
ssl_wrap_socket.FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME = (
self._connection._ocsp_response_cache_filename if self._connection else None
)
# OCSP root timeout
ssl_wrap_socket.FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT = (
self._connection._ocsp_root_certs_dict_lock_timeout
if self._connection
else OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT
)

# This is to address the issue where requests hangs
_ = "dummy".encode("idna").decode("utf-8")
Expand Down
131 changes: 73 additions & 58 deletions src/snowflake/connector/ocsp_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from . import constants
from .backoff_policies import exponential_backoff
from .cache import CacheEntry, SFDictCache, SFDictFileCache
from .constants import OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT
from .telemetry import TelemetryField, generate_telemetry_data_dict
from .url_util import extract_top_level_domain_from_hostname, url_encode_str
from .util_text import _base64_bytes_to_str
Expand Down Expand Up @@ -1037,6 +1038,7 @@ def __init__(
use_ocsp_cache_server=None,
use_post_method: bool = True,
use_fail_open: bool = True,
root_certs_dict_lock_timeout: int = OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT,
**kwargs,
) -> None:
self.test_mode = os.getenv("SF_OCSP_TEST_MODE", None)
Expand All @@ -1045,6 +1047,7 @@ def __init__(
logger.debug("WARNING - DRIVER CONFIGURED IN TEST MODE")

self._use_post_method = use_post_method
self._root_certs_dict_lock_timeout = root_certs_dict_lock_timeout
self.OCSP_CACHE_SERVER = OCSPServer(
top_level_domain=extract_top_level_domain_from_hostname(
kwargs.pop("hostname", None)
Expand Down Expand Up @@ -1415,67 +1418,79 @@ def _check_ocsp_response_cache_server(

def _lazy_read_ca_bundle(self) -> None:
"""Reads the local cabundle file and cache it in memory."""
with SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK:
if SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
# return if already loaded
return

lock_acquired = SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.acquire(
timeout=self._root_certs_dict_lock_timeout
)
if lock_acquired:
try:
ca_bundle = environ.get("REQUESTS_CA_BUNDLE") or environ.get(
"CURL_CA_BUNDLE"
)
if ca_bundle and path.exists(ca_bundle):
# if the user/application specifies cabundle.
self.read_cert_bundle(ca_bundle)
else:
import sys

# This import that depends on these libraries is to import certificates from them,
# we would like to have these as up to date as possible.
from requests import certs
if SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
# return if already loaded
return

if (
hasattr(certs, "__file__")
and path.exists(certs.__file__)
and path.exists(
path.join(path.dirname(certs.__file__), "cacert.pem")
)
):
# if cacert.pem exists next to certs.py in request
# package.
ca_bundle = path.join(
path.dirname(certs.__file__), "cacert.pem"
)
try:
ca_bundle = environ.get("REQUESTS_CA_BUNDLE") or environ.get(
"CURL_CA_BUNDLE"
)
if ca_bundle and path.exists(ca_bundle):
# if the user/application specifies cabundle.
self.read_cert_bundle(ca_bundle)
elif hasattr(sys, "_MEIPASS"):
# if pyinstaller includes cacert.pem
cabundle_candidates = [
["botocore", "vendored", "requests", "cacert.pem"],
["requests", "cacert.pem"],
["cacert.pem"],
]
for filename in cabundle_candidates:
ca_bundle = path.join(sys._MEIPASS, *filename)
if path.exists(ca_bundle):
self.read_cert_bundle(ca_bundle)
break
else:
logger.error("No cabundle file is found in _MEIPASS")
try:
import certifi

self.read_cert_bundle(certifi.where())
except Exception:
logger.debug("no certifi is installed. ignored.")

except Exception as e:
logger.error("Failed to read ca_bundle: %s", e)

if not SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
logger.error(
"No CA bundle file is found in the system. "
"Set REQUESTS_CA_BUNDLE to the file."
)
else:
import sys

# This import that depends on these libraries is to import certificates from them,
# we would like to have these as up to date as possible.
from requests import certs

if (
hasattr(certs, "__file__")
and path.exists(certs.__file__)
and path.exists(
path.join(path.dirname(certs.__file__), "cacert.pem")
)
):
# if cacert.pem exists next to certs.py in request
# package.
ca_bundle = path.join(
path.dirname(certs.__file__), "cacert.pem"
)
self.read_cert_bundle(ca_bundle)
elif hasattr(sys, "_MEIPASS"):
# if pyinstaller includes cacert.pem
cabundle_candidates = [
["botocore", "vendored", "requests", "cacert.pem"],
["requests", "cacert.pem"],
["cacert.pem"],
]
for filename in cabundle_candidates:
ca_bundle = path.join(sys._MEIPASS, *filename)
if path.exists(ca_bundle):
self.read_cert_bundle(ca_bundle)
break
else:
logger.error("No cabundle file is found in _MEIPASS")
try:
import certifi

self.read_cert_bundle(certifi.where())
except Exception:
logger.debug("no certifi is installed. ignored.")

except Exception as e:
logger.error("Failed to read ca_bundle: %s", e)

if not SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
logger.error(
"No CA bundle file is found in the system. "
"Set REQUESTS_CA_BUNDLE to the file."
)
finally:
SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.release()
else:
logger.info(
"Failed to acquire lock for ROOT_CERTIFICATES_DICT_LOCK. "
"Skipping reading CA bundle."
)
return

@staticmethod
def _calculate_tolerable_validity(this_update: float, next_update: float) -> int:
Expand Down
6 changes: 5 additions & 1 deletion src/snowflake/connector/ssl_wrap_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import certifi
import OpenSSL.SSL

from .constants import OCSPMode
from .constants import OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT, OCSPMode
from .errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED
from .errors import OperationalError
from .session_manager import SessionManager
Expand All @@ -31,6 +31,9 @@

DEFAULT_OCSP_MODE: OCSPMode = OCSPMode.FAIL_OPEN
FEATURE_OCSP_MODE: OCSPMode = DEFAULT_OCSP_MODE
FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT: int = (
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT
)

"""
OCSP Response cache file name
Expand Down Expand Up @@ -179,6 +182,7 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
ocsp_response_cache_uri=FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME,
use_fail_open=FEATURE_OCSP_MODE == OCSPMode.FAIL_OPEN,
hostname=server_hostname,
root_certs_dict_lock_timeout=FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT,
).validate(server_hostname, ret.connection)
if not v:
raise OperationalError(
Expand Down
Loading
Loading