Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,12 @@
from .telemetry import TelemetryClient, TelemetryData, TelemetryField
from .time_util import HeartBeatTimer, get_time_millis
from .url_util import extract_top_level_domain_from_hostname
from .util_text import construct_hostname, parse_account, split_statements
from .util_text import (
construct_hostname,
is_valid_account_identifier,
parse_account,
split_statements,
)
from .wif_util import AttestationProvider

if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
Expand Down Expand Up @@ -664,6 +669,18 @@ def __init__(
self._file_operation_parser = FileOperationParser(self)
self._stream_downloader = StreamDownloader(self)

def _validate_account(self, account_str):
if not is_valid_account_identifier(account_str):
Error.errorhandler_wrapper(
self,
None,
ProgrammingError,
{
"msg": "Invalid account identifier: only letters, digits, '_' and '-' allowed; no dots or slashes",
"errno": ER_NO_ACCOUNT_NAME,
},
)

# Deprecated
@property
def insecure_mode(self) -> bool:
Expand Down Expand Up @@ -1743,8 +1760,11 @@ def __config(self, **kwargs):
ProgrammingError,
{"msg": "Account must be specified", "errno": ER_NO_ACCOUNT_NAME},
)
if self._account and "." in self._account:
self._account = parse_account(self._account)

if self._account:
self._validate_account(self._account)
if "." in self._account:
self._account = parse_account(self._account)

if not isinstance(self._backoff_policy, Callable) or not isinstance(
self._backoff_policy(), Iterator
Expand Down
18 changes: 18 additions & 0 deletions src/snowflake/connector/util_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,24 @@ def _is_china_region(r: str) -> bool:
return host


ACCOUNT_ID_VALIDATOR_RE = re.compile(r"^[A-Za-z0-9_-]+$")


def is_valid_account_identifier(account: str) -> bool:
"""Validate the Snowflake account identifier format.

The account identifier must be a single label (no dots or slashes) composed
only of ASCII letters, digits, underscores, or hyphens.
"""
if not isinstance(account, str) or not account:
return False

if "/" in account or "\\" in account:
return False

return all(bool(ACCOUNT_ID_VALIDATOR_RE.fullmatch(p)) for p in account.split("."))


def parse_account(account):
url_parts = account.split(".")
# if this condition is true, then we have some extra
Expand Down
47 changes: 46 additions & 1 deletion test/unit/test_parse_account.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#!/usr/bin/env python
from __future__ import annotations

from snowflake.connector.util_text import parse_account
import pytest

from snowflake.connector import ProgrammingError, connect
from snowflake.connector.util_text import is_valid_account_identifier, parse_account


def test_parse_account_basic():
Expand All @@ -12,3 +15,45 @@ def test_parse_account_basic():
assert (
parse_account("account1-jkabfvdjisoa778wqfgeruishafeuw89q.global") == "account1"
)


@pytest.mark.parametrize(
"value",
[
"abc",
"aaa.bbb.ccc",
"aaa.bbb.ccc.ddd" "ABC",
"a_b-c1",
"account1",
"my_account",
"my-account",
"account_123",
"ACCOUNT_NAME",
],
)
def test_is_valid_account_identifier(value):
assert is_valid_account_identifier(value) is True


@pytest.mark.parametrize(
"value",
[
"a/b",
"a\\b",
"aa.bb.ccc/dddd",
"account@domain",
"account name",
"account\ttab",
"account\nnewline",
"account:port",
"account;semicolon",
"account'quote",
'account"doublequote',
],
)
def test_is_invalid_account_identifier(value):
assert is_valid_account_identifier(value) is False
with pytest.raises(ProgrammingError) as err:
connect(account=value, user="jdoe", password="***")

assert "Invalid account identifier" in str(err)
Loading