Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 262 additions & 0 deletions src/packaging/direct_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
from __future__ import annotations

import dataclasses
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, TypeVar

if TYPE_CHECKING: # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

# Public names exported by a star-import of this module.
__all__ = [
    "ArchiveInfo",
    "DirInfo",
    "DirectUrl",
    "DirectUrlValidationError",
    "VcsInfo",
]

# Type variable tying the ``expected_type`` argument of the lookup helpers
# (``_get`` / ``_get_required``) to their return type.
_T = TypeVar("_T")


class _FromMappingProtocol(Protocol):  # pragma: no cover
    """Structural type for classes that expose a ``_from_dict`` classmethod."""

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build an instance of the class from the mapping *d*."""
        ...


# Type variable restricted to classes implementing ``_FromMappingProtocol``;
# used by ``_get_object`` so its return type follows the ``target_type`` argument.
_FromMappingProtocolT = TypeVar("_FromMappingProtocolT", bound=_FromMappingProtocol)


def _json_dict_factory(data: list[tuple[str, Any]]) -> dict[str, Any]:
return {key: value for key, value in data if value is not None}


def _get(d: Mapping[str, Any], expected_type: type[_T], key: str) -> _T | None:
"""Get a value from the dictionary and verify it's the expected type."""
if (value := d.get(key)) is None:
return None
if not isinstance(value, expected_type):
raise DirectUrlValidationError(
f"Unexpected type {type(value).__name__} "
f"(expected {expected_type.__name__})",
context=key,
)
return value


def _get_required(d: Mapping[str, Any], expected_type: type[_T], key: str) -> _T:
    """Fetch a mandatory, type-checked value for *key* from *d*.

    Delegates the lookup and type check to :func:`_get`. When that yields
    ``None`` (key absent or explicitly ``None``), a
    ``_DirectUrlRequiredKeyError`` carrying *key* is raised.
    """
    found = _get(d, expected_type, key)
    if found is not None:
        return found
    raise _DirectUrlRequiredKeyError(key)


def _get_object(
    d: Mapping[str, Any], target_type: type[_FromMappingProtocolT], key: str
) -> _FromMappingProtocolT | None:
    """Fetch a nested mapping under *key* and turn it into *target_type*.

    Returns ``None`` when the key is absent or ``None``. Any exception raised
    by ``target_type._from_dict`` is re-raised as a
    :class:`DirectUrlValidationError` with *key* as context, chained to the
    original exception.
    """
    raw = _get(d, Mapping, key)  # type: ignore[type-abstract]
    if raw is None:
        return None
    try:
        converted = target_type._from_dict(raw)
    except Exception as exc:
        raise DirectUrlValidationError(exc, context=key) from exc
    return converted


class DirectUrlValidationError(Exception):
    """Raised when input data is not spec-compliant.

    ``message`` holds the human-readable problem description and ``context``
    the (possibly dotted) key path at which it was detected, if any.
    """

    context: str | None = None
    message: str

    def __init__(
        self,
        cause: str | Exception,
        *,
        context: str | None = None,
    ) -> None:
        if not isinstance(cause, DirectUrlValidationError):
            # Plain string or foreign exception: take its text as the message.
            self.message = str(cause)
            self.context = context
            return

        # Wrapping another validation error: keep its message and prefix its
        # context with ours so the result reads as a dotted path.
        self.message = cause.message
        if not cause.context:
            self.context = context  # pragma: no cover
        elif context:
            self.context = f"{context}.{cause.context}"
        else:
            self.context = cause.context

    def __str__(self) -> str:
        if not self.context:
            return self.message
        return f"{self.message} in {self.context!r}"


class _DirectUrlRequiredKeyError(DirectUrlValidationError):
    """Validation error for a required key whose value is missing."""

    def __init__(self, key: str) -> None:
        message = "Missing required value"
        super().__init__(message, context=key)


@dataclass(frozen=True, init=False)
class VcsInfo:
    """Immutable record with ``vcs``, ``commit_id`` and ``requested_revision``.

    Instances are created with keyword arguments only.
    """

    vcs: str
    commit_id: str
    requested_revision: str | None = None

    def __init__(
        self,
        *,
        vcs: str,
        commit_id: str,
        requested_revision: str | None = None,
    ) -> None:
        # The dataclass is frozen, so attributes are set through
        # object.__setattr__ rather than normal assignment.
        for name, value in (
            ("vcs", vcs),
            ("commit_id", commit_id),
            ("requested_revision", requested_revision),
        ):
            object.__setattr__(self, name, value)

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build an instance from *d*; ``vcs`` and ``commit_id`` are required."""
        # The set of VCS names is not closed, so the vcs value itself is
        # not validated beyond being a string.
        vcs = _get_required(d, str, "vcs")
        requested_revision = _get(d, str, "requested_revision")
        commit_id = _get_required(d, str, "commit_id")
        return cls(
            vcs=vcs,
            requested_revision=requested_revision,
            commit_id=commit_id,
        )


@dataclass(frozen=True, init=False)
class ArchiveInfo:
    """Immutable record with the ``hashes`` mapping and the legacy ``hash``.

    Instances are created with keyword arguments only.
    """

    hashes: Mapping[str, str] | None = None
    hash: str | None = None  # Deprecated, use `hashes` instead

    def __init__(
        self,
        *,
        hashes: Mapping[str, str] | None = None,
        hash: str | None = None,
    ) -> None:
        # The dataclass is frozen, so attributes are set through
        # object.__setattr__ rather than normal assignment.
        for name, value in (("hashes", hashes), ("hash", hash)):
            object.__setattr__(self, name, value)

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build an instance from *d* and check ``hashes`` / ``hash`` agree.

        Raises :class:`DirectUrlValidationError` when a value in ``hashes``
        is not a string, when ``hash`` lacks an ``=`` separator, or when
        both are given and ``hash`` does not match the corresponding entry
        of ``hashes``.
        """
        archive_info = cls(
            hashes=_get(d, Mapping, "hashes"),  # type: ignore[type-abstract]
            hash=_get(d, str, "hash"),
        )

        hashes = archive_info.hashes or {}
        for digest in hashes.values():
            if not isinstance(digest, str):
                raise DirectUrlValidationError(
                    "Hash values must be strings", context="hashes"
                )

        legacy_hash = archive_info.hash
        if legacy_hash is None:
            return archive_info
        if "=" not in legacy_hash:
            raise DirectUrlValidationError(
                "Invalid hash format (expected '<algorithm>=<hash>')",
                context="hash",
            )
        if archive_info.hashes is None:
            return archive_info

        # if `hashes` are present, the legacy `hash` must match one of them
        hash_algorithm, _, hash_value = legacy_hash.partition("=")
        if hash_algorithm not in hashes:
            raise DirectUrlValidationError(
                f"Algorithm {hash_algorithm!r} used in hash field "
                f"is not present in hashes field",
                context="hashes",
            )
        if hashes[hash_algorithm] != hash_value:
            raise DirectUrlValidationError(
                f"Algorithm {hash_algorithm!r} used in hash field "
                f"has different value in hashes field",
                context="hash",
            )
        return archive_info


@dataclass(frozen=True, init=False)
class DirInfo:
    """Immutable record with the optional ``editable`` flag.

    Instances are created with keyword arguments only.
    """

    editable: bool | None = None

    def __init__(self, *, editable: bool | None = None) -> None:
        # The dataclass is frozen, so the attribute is set through
        # object.__setattr__ rather than normal assignment.
        object.__setattr__(self, "editable", editable)

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build an instance from *d*; ``editable`` must be a bool if given."""
        editable = _get(d, bool, "editable")
        return cls(editable=editable)


@dataclass(frozen=True, init=False)
class DirectUrl:
    """Immutable record holding ``url`` and its optional companion fields.

    The optional fields are ``archive_info``, ``vcs_info``, ``dir_info`` and
    ``subdirectory``. Instances are created with keyword arguments only; the
    constructor itself performs no validation (see :meth:`validate`).
    """

    url: str
    archive_info: ArchiveInfo | None = None
    vcs_info: VcsInfo | None = None
    dir_info: DirInfo | None = None
    subdirectory: str | None = None  # XXX Path or str?

    def __init__(
        self,
        *,
        url: str,
        archive_info: ArchiveInfo | None = None,
        vcs_info: VcsInfo | None = None,
        dir_info: DirInfo | None = None,
        subdirectory: str | None = None,
    ) -> None:
        # The dataclass is frozen, so attributes are set through
        # object.__setattr__ rather than normal assignment.
        for name, value in (
            ("url", url),
            ("archive_info", archive_info),
            ("vcs_info", vcs_info),
            ("dir_info", dir_info),
            ("subdirectory", subdirectory),
        ):
            object.__setattr__(self, name, value)

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build an instance from *d* and enforce the cross-field rules.

        Raises :class:`DirectUrlValidationError` when a field is missing or
        has the wrong type, when the number of info objects present is not
        exactly one, or when ``dir_info`` is given with a non-``file://`` URL.
        """
        url = _get_required(d, str, "url")
        archive_info = _get_object(d, ArchiveInfo, "archive_info")
        vcs_info = _get_object(d, VcsInfo, "vcs_info")
        dir_info = _get_object(d, DirInfo, "dir_info")
        subdirectory = _get(d, str, "subdirectory")
        direct_url = cls(
            url=url,
            archive_info=archive_info,
            vcs_info=vcs_info,
            dir_info=dir_info,
            subdirectory=subdirectory,
        )

        info_candidates = (
            direct_url.vcs_info,
            direct_url.archive_info,
            direct_url.dir_info,
        )
        present_count = sum(1 for info in info_candidates if info)
        if present_count != 1:
            raise DirectUrlValidationError(
                "Exactly one of vcs_info, archive_info, dir_info must be present"
            )

        if direct_url.dir_info is not None:
            if not direct_url.url.startswith("file://"):
                raise DirectUrlValidationError(
                    "URL scheme must be file:// when dir_info is present",
                    context="url",
                )
        # XXX subdirectory must be relative, can we, should we validate that here?
        # XXX url MUST be stripped of any sensitive authentication information.
        # We can't validate it here because it MAY contain git or other non security
        # sensitive auth strings.
        return direct_url

    @classmethod
    def from_dict(cls, d: Mapping[str, Any], /) -> Self:
        """Create a validated instance from the mapping *d*.

        Raises :class:`DirectUrlValidationError` if *d* is not valid.
        """
        return cls._from_dict(d)

    def to_dict(self) -> Mapping[str, Any]:
        """Return the instance as a dict, omitting fields whose value is ``None``."""
        return dataclasses.asdict(self, dict_factory=_json_dict_factory)

    def validate(self) -> None:
        """Validate the DirectUrl instance against the specification.

        The check is done by round-tripping the instance through
        :meth:`to_dict` and :meth:`from_dict`.

        Raises :class:`DirectUrlValidationError` otherwise.
        """
        as_mapping = self.to_dict()
        self.from_dict(as_mapping)
Loading
Loading