Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/momento/auth/credential_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import os
from dataclasses import dataclass
from typing import Dict, Optional
from warnings import warn

from momento.errors.exceptions import InvalidArgumentException
from momento.internal.services import Service

from . import momento_endpoint_resolver

Expand All @@ -27,6 +31,8 @@ def from_environment_variable(
) -> CredentialProvider:
"""Reads and parses a Momento auth token stored as an environment variable.

Deprecated as of v1.28.0. Use from_environment_variables_v2 instead.

Args:
env_var_name (str): Name of the environment variable from which the API key will be read
control_endpoint (Optional[str], optional): Optionally overrides the default control endpoint.
Expand All @@ -42,6 +48,11 @@ def from_environment_variable(
Returns:
CredentialProvider
"""
warn(
"from_environment_variable is deprecated, use from_environment_variables_v2 instead",
DeprecationWarning,
stacklevel=2,
)
api_key = os.getenv(env_var_name)
if not api_key:
raise RuntimeError(f"Missing required environment variable {env_var_name}")
Expand All @@ -56,6 +67,8 @@ def from_string(
) -> CredentialProvider:
"""Reads and parses a Momento auth token.

Deprecated as of v1.28.0. Use from_api_key_v2 or from_disposable_token instead.

Args:
auth_token (str): the Momento API key (previously: auth token)
control_endpoint (Optional[str], optional): Optionally overrides the default control endpoint.
Expand All @@ -68,6 +81,11 @@ def from_string(
Returns:
CredentialProvider
"""
warn(
"from_string is deprecated, use from_api_key_v2 or from_disposable_token instead",
DeprecationWarning,
stacklevel=2,
)
token_and_endpoints = momento_endpoint_resolver.resolve(auth_token)
control_endpoint = control_endpoint or token_and_endpoints.control_endpoint
cache_endpoint = cache_endpoint or token_and_endpoints.cache_endpoint
Expand Down Expand Up @@ -102,3 +120,86 @@ def _obscure(self, value: str) -> str:

def get_auth_token(self) -> str:
return self.auth_token

@staticmethod
def from_api_key_v2(api_key: str, endpoint: str) -> CredentialProvider:
"""Creates a CredentialProvider from a v2 API key and endpoint.

Args:
api_key (str): The v2 API key.
endpoint (str): The Momento service endpoint.

Returns:
CredentialProvider
"""
if len(api_key) == 0:
raise InvalidArgumentException("API key cannot be empty.", Service.AUTH)
if len(endpoint) == 0:
raise InvalidArgumentException("Endpoint cannot be empty.", Service.AUTH)

if not momento_endpoint_resolver._is_v2_api_key(api_key):
raise InvalidArgumentException(
"Received an invalid v2 API key. Are you using the correct key and the correct CredentialProvider method?",
Service.AUTH,
)
return CredentialProvider(
auth_token=api_key,
control_endpoint=momento_endpoint_resolver._MOMENTO_CONTROL_ENDPOINT_PREFIX + endpoint,
cache_endpoint=momento_endpoint_resolver._MOMENTO_CACHE_ENDPOINT_PREFIX + endpoint,
token_endpoint=momento_endpoint_resolver._MOMENTO_TOKEN_ENDPOINT_PREFIX + endpoint,
port=443,
)

@staticmethod
def from_environment_variables_v2(
api_key_env_var: str = "MOMENTO_API_KEY", endpoint_env_var: str = "MOMENTO_ENDPOINT"
) -> CredentialProvider:
"""Creates a CredentialProvider from an endpoint and v2 API key stored in the environment variables MOMENTO_API_KEY and MOMENTO_ENDPOINT.

Args:
api_key_env_var (str): Optionally provide an alternate environment variable name from which the v2 API key will be read.
endpoint_env_var (str): Optionally provide an alternate environment variable name from which the Momento service endpoint will be read.

Returns:
CredentialProvider
"""
if len(api_key_env_var) == 0:
raise InvalidArgumentException("API key environment variable name cannot be empty.", Service.AUTH)
if len(endpoint_env_var) == 0:
raise InvalidArgumentException("Endpoint environment variable name cannot be empty.", Service.AUTH)

api_key = os.getenv(api_key_env_var)
if not api_key:
raise RuntimeError(f"Missing required environment variable {api_key_env_var}")
endpoint = os.getenv(endpoint_env_var)
if not endpoint:
raise RuntimeError(f"Missing required environment variable {endpoint_env_var}")

if not momento_endpoint_resolver._is_v2_api_key(api_key):
raise InvalidArgumentException(
"Received an invalid v2 API key. Are you using the correct key? Or did you mean to use `from_environment_variable()` with a legacy key instead?",
Service.AUTH,
)
return CredentialProvider.from_api_key_v2(api_key, endpoint)

@staticmethod
def from_disposable_token(auth_token: str) -> CredentialProvider:
"""Reads and parses a Momento disposable auth token.

Args:
auth_token (str): the Momento disposable auth token

Returns:
CredentialProvider
"""
if len(auth_token) == 0:
raise InvalidArgumentException("Disposable token cannot be empty.", Service.AUTH)
token_and_endpoints = momento_endpoint_resolver.resolve(auth_token)
auth_token = token_and_endpoints.auth_token
return CredentialProvider(
auth_token,
token_and_endpoints.control_endpoint,
token_and_endpoints.cache_endpoint,
token_and_endpoints.token_endpoint,
443,
)
25 changes: 25 additions & 0 deletions src/momento/auth/momento_endpoint_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
_MOMENTO_TOKEN_ENDPOINT_PREFIX = "token."
_CONTROL_ENDPOINT_CLAIM_ID = "cp"
_CACHE_ENDPOINT_CLAIM_ID = "c"
_API_KEY_TYPE_CLAIM_ID = "t"
_GLOBAL_API_KEY_TYPE = "g"


@dataclass
Expand All @@ -31,6 +33,14 @@ class _Base64DecodedV1Token:


def resolve(auth_token: str) -> _TokenAndEndpoints:
"""Helper function used by from_string and from_disposable_token to parse legacy and v1 auth tokens.

Args:
auth_token (str): The auth token to be resolved.

Returns:
_TokenAndEndpoints
"""
if not auth_token:
raise InvalidArgumentException("malformed auth token", Service.AUTH)

Expand All @@ -44,6 +54,11 @@ def resolve(auth_token: str) -> _TokenAndEndpoints:
auth_token=info["api_key"], # type: ignore[misc]
)
else:
if _is_v2_api_key(auth_token):
raise InvalidArgumentException(
"Unexpectedly received a v2 API key. Are you using the correct key and the correct CredentialProvider method?",
Service.AUTH,
)
return _get_endpoint_from_token(auth_token)


Expand All @@ -67,3 +82,13 @@ def _is_base64(value: Union[bytes, str]) -> bool:
return base64.b64encode(base64.b64decode(value)) == value
except Exception:
return False


def _is_v2_api_key(key: str) -> bool:
if _is_base64(key):
return False
try:
claims = jwt.decode(key, options={"verify_signature": False}) # type: ignore[misc]
return _API_KEY_TYPE_CLAIM_ID in claims and claims[_API_KEY_TYPE_CLAIM_ID] == _GLOBAL_API_KEY_TYPE # type: ignore[misc]
except DecodeError:
return False
181 changes: 181 additions & 0 deletions tests/momento/auth/test_credential_provider.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import base64
import json
import os
import re

import jwt
import pytest
from momento.auth.credential_provider import CredentialProvider
from momento.auth.momento_endpoint_resolver import _Base64DecodedV1Token
from momento.errors.exceptions import InvalidArgumentException

from tests.utils import uuid_str

Expand All @@ -23,6 +25,15 @@
os.environ[test_env_var_name] = test_token
os.environ[test_v1_env_var_name] = test_encoded_v1_token.decode("utf-8")

# For v2 API key tests
test_v2_key_message = {"t": "g", "jti": "some-id"}
test_v2_api_key = jwt.encode(test_v2_key_message, "secret", algorithm="HS512")
test_v2_key_env_var_name = "MOMENTO_API_KEY"
test_v2_endpoint = "testEndpoint"
test_v2_endpoint_env_var_name = "MOMENTO_ENDPOINT"
os.environ[test_v2_key_env_var_name] = test_v2_api_key
os.environ[test_v2_endpoint_env_var_name] = test_v2_endpoint


@pytest.mark.parametrize(
"provider, auth_token, control_endpoint, cache_endpoint",
Expand Down Expand Up @@ -97,3 +108,173 @@ def test_endpoints(provider: CredentialProvider, auth_token: str, control_endpoi
def test_env_token_raises_if_not_exists() -> None:
with pytest.raises(RuntimeError, match=r"Missing required environment variable"):
CredentialProvider.from_environment_variable(env_var_name=uuid_str())


@pytest.mark.parametrize(
"provider, expected_api_key, expected_control_endpoint, expected_cache_endpoint, expected_token_endpoint",
[
(
CredentialProvider.from_api_key_v2(
api_key=test_v2_api_key,
endpoint=test_v2_endpoint,
),
test_v2_api_key,
f"control.{test_v2_endpoint}",
f"cache.{test_v2_endpoint}",
f"token.{test_v2_endpoint}",
),
(
CredentialProvider.from_environment_variables_v2(
api_key_env_var=test_v2_key_env_var_name,
endpoint_env_var=test_v2_endpoint_env_var_name,
),
test_v2_api_key,
f"control.{test_v2_endpoint}",
f"cache.{test_v2_endpoint}",
f"token.{test_v2_endpoint}",
),
(
CredentialProvider.from_environment_variables_v2(),
test_v2_api_key,
f"control.{test_v2_endpoint}",
f"cache.{test_v2_endpoint}",
f"token.{test_v2_endpoint}",
),
],
)
def test_v2_api_key_endpoints(
provider: CredentialProvider,
expected_api_key: str,
expected_control_endpoint: str,
expected_cache_endpoint: str,
expected_token_endpoint: str,
) -> None:
assert provider.auth_token == expected_api_key
assert provider.control_endpoint == expected_control_endpoint
assert provider.cache_endpoint == expected_cache_endpoint
assert provider.token_endpoint == expected_token_endpoint


def test_v2_key_from_string_raises_if_api_key_empty() -> None:
with pytest.raises(InvalidArgumentException, match="API key cannot be empty"):
CredentialProvider.from_api_key_v2(api_key="", endpoint=test_v2_endpoint)


def test_v2_key_from_string_raises_if_endpoint_empty() -> None:
with pytest.raises(InvalidArgumentException, match="Endpoint cannot be empty"):
CredentialProvider.from_api_key_v2(api_key=test_v2_api_key, endpoint="")


def test_v2_key_from_env_raises_if_env_var_name_empty() -> None:
with pytest.raises(InvalidArgumentException, match="API key environment variable name cannot be empty"):
CredentialProvider.from_environment_variables_v2(
api_key_env_var="", endpoint_env_var=test_v2_endpoint_env_var_name
)


def test_v2_key_from_env_raises_if_env_var_missing() -> None:
with pytest.raises(RuntimeError, match="Missing required environment variable"):
CredentialProvider.from_environment_variables_v2(
api_key_env_var=uuid_str(), endpoint_env_var=test_v2_endpoint_env_var_name
)


def test_v2_key_from_env_raises_if_endpoint_empty() -> None:
with pytest.raises(InvalidArgumentException, match="Endpoint environment variable name cannot be empty"):
CredentialProvider.from_environment_variables_v2(api_key_env_var=test_v2_key_env_var_name, endpoint_env_var="")


def test_v2_key_from_env_raises_if_api_key_empty_string() -> None:
empty_api_key_env_var = uuid_str()
os.environ[empty_api_key_env_var] = ""
with pytest.raises(RuntimeError, match="Missing required environment variable"):
CredentialProvider.from_environment_variables_v2(
api_key_env_var=empty_api_key_env_var, endpoint_env_var=test_v2_endpoint_env_var_name
)


def test_v2_key_from_string_raises_if_base64_api_key() -> None:
with pytest.raises(
InvalidArgumentException,
match=re.escape(
"Received an invalid v2 API key. Are you using the correct key and the correct CredentialProvider method?"
),
):
CredentialProvider.from_api_key_v2(
api_key=test_encoded_v1_token.decode("utf-8"), endpoint=test_v2_endpoint_env_var_name
)


def test_v2_key_from_env_raises_if_base64_api_key() -> None:
with pytest.raises(
InvalidArgumentException,
match=re.escape(
"Received an invalid v2 API key. Are you using the correct key? Or did you mean to use `from_environment_variable()` with a legacy key instead?"
),
):
CredentialProvider.from_environment_variables_v2(
api_key_env_var=test_v1_env_var_name, endpoint_env_var=test_v2_endpoint_env_var_name
)


def test_v2_key_from_string_raises_if_pre_v1_token() -> None:
with pytest.raises(
InvalidArgumentException,
match=re.escape(
"Received an invalid v2 API key. Are you using the correct key and the correct CredentialProvider method?"
),
):
CredentialProvider.from_api_key_v2(api_key=test_token, endpoint=test_v2_endpoint_env_var_name)


def test_v2_key_from_env_raises_if_pre_v1_token() -> None:
with pytest.raises(
InvalidArgumentException,
match=re.escape(
"Received an invalid v2 API key. Are you using the correct key? Or did you mean to use `from_environment_variable()` with a legacy key instead?"
),
):
CredentialProvider.from_environment_variables_v2(
api_key_env_var=test_env_var_name, endpoint_env_var=test_v2_endpoint_env_var_name
)


def test_v2_key_provided_to_from_string() -> None:
with pytest.raises(
InvalidArgumentException,
match=re.escape(
"Unexpectedly received a v2 API key. Are you using the correct key and the correct CredentialProvider method?"
),
):
CredentialProvider.from_string(auth_token=test_v2_api_key)


def test_v2_key_provided_to_from_disposable_token() -> None:
with pytest.raises(
InvalidArgumentException,
match=re.escape(
"Unexpectedly received a v2 API key. Are you using the correct key and the correct CredentialProvider method?"
),
):
CredentialProvider.from_disposable_token(auth_token=test_v2_api_key)


def test_from_disposable_token_raises_if_token_empty() -> None:
with pytest.raises(InvalidArgumentException, match="Disposable token cannot be empty."):
CredentialProvider.from_disposable_token(auth_token="")


def test_from_disposable_token_accepts_v1_api_key() -> None:
provider = CredentialProvider.from_disposable_token(auth_token=test_encoded_v1_token.decode("utf-8"))
assert provider.auth_token == test_v1_api_key
assert provider.control_endpoint == "control.test.momentohq.com"
assert provider.cache_endpoint == "cache.test.momentohq.com"
assert provider.token_endpoint == "token.test.momentohq.com"


def test_from_disposable_token_accepts_pre_v1_token() -> None:
provider = CredentialProvider.from_disposable_token(auth_token=test_token)
assert provider.auth_token == test_token
assert provider.control_endpoint == test_control_endpoint
assert provider.cache_endpoint == test_cache_endpoint
assert provider.token_endpoint == f"token.{test_cache_endpoint}"
Loading