1 change: 1 addition & 0 deletions docs/api-reference.md
@@ -12,6 +12,7 @@
 broadcast_shapes
 cov
 create_diagonal
+default_dtype
 expand_dims
 isclose
 kron
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
@@ -8,6 +8,7 @@
     broadcast_shapes,
     cov,
     create_diagonal,
+    default_dtype,
     expand_dims,
     kron,
     nunique,
@@ -27,6 +28,7 @@
     "broadcast_shapes",
     "cov",
     "create_diagonal",
+    "default_dtype",
     "expand_dims",
     "isclose",
     "kron",
46 changes: 41 additions & 5 deletions src/array_api_extra/_lib/_funcs.py
@@ -4,7 +4,7 @@
 import warnings
 from collections.abc import Callable, Sequence
 from types import ModuleType, NoneType
-from typing import cast, overload
+from typing import Literal, cast, overload

 from ._at import at
 from ._utils import _compat, _helpers
@@ -16,7 +16,7 @@
     meta_namespace,
     ndindex,
 )
-from ._utils._typing import Array
+from ._utils._typing import Array, Device, DType

 __all__ = [
     "apply_where",
@@ -438,6 +438,44 @@ def create_diagonal(
     return xp.reshape(diag, (*batch_dims, n, n))

+def default_dtype(
+    xp: ModuleType,
+    kind: Literal[
+        "real floating", "complex floating", "integral", "indexing"
+    ] = "real floating",
+    *,
+    device: Device | None = None,
+) -> DType:
+    """
+    Return the default dtype for the given namespace and device.
+
+    This is a convenience shorthand for
+    ``xp.__array_namespace_info__().default_dtypes(device=device)[kind]``.
+
+    Parameters
+    ----------
+    xp : array_namespace
+        The standard-compatible namespace for which to get the default dtype.
+    kind : {'real floating', 'complex floating', 'integral', 'indexing'}, optional
+        The kind of dtype to return. Default is 'real floating'.
+    device : Device, optional
+        The device for which to get the default dtype. Default: current device.
+
+    Returns
+    -------
+    dtype
+        The default dtype for the given namespace, kind, and device.
+    """
+    dtypes = xp.__array_namespace_info__().default_dtypes(device=device)
+    try:
+        return dtypes[kind]
+    except KeyError as e:
+        domain = ("real floating", "complex floating", "integral", "indexing")
+        assert set(dtypes) == set(domain), f"Non-compliant namespace: {dtypes}"
+        msg = f"Unknown kind '{kind}'. Expected one of {domain}."
+        raise ValueError(msg) from e
+
+
 def expand_dims(
     a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
 ) -> Array:
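A minimal usage sketch of the new helper (not part of the diff; `array_api_strict` is just one standard-compliant namespace, and the exact dtypes depend on the backend and its configuration):

    import array_api_strict as xp
    from array_api_extra import default_dtype

    default_dtype(xp)                       # default real dtype, e.g. float64
    default_dtype(xp, "complex floating")   # e.g. complex128
    default_dtype(xp, "integral")           # e.g. int64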
@@ -728,9 +766,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
     x = xp.reshape(x, (-1,))
     x = xp.sort(x)
     mask = x != xp.roll(x, -1)
-    default_int = xp.__array_namespace_info__().default_dtypes(
-        device=_compat.device(x)
-    )["integral"]
+    default_int = default_dtype(xp, "integral", device=_compat.device(x))
     return xp.maximum(
         # Special cases:
         # - array is size 0
38 changes: 28 additions & 10 deletions tests/conftest.py
@@ -149,16 +149,8 @@ def xp(

     if library.like(Backend.JAX):
         _setup_jax(library)
-
-    elif library == Backend.TORCH_GPU:
-        import torch.cuda
-
-        if not torch.cuda.is_available():
-            pytest.skip("no CUDA device available")
-        xp.set_default_device("cuda")
-
-    elif library == Backend.TORCH:  # CPU
-        xp.set_default_device("cpu")
+    elif library.like(Backend.TORCH):
+        _setup_torch(library)

     yield xp

@@ -179,6 +171,23 @@ def _setup_jax(library: Backend) -> None:
     jax.config.update("jax_default_device", device)


+def _setup_torch(library: Backend) -> None:
+    import torch
+
+    # This is already the default, but some tests or env variables may change it.
+    # TODO test both float32 and float64, like in scipy.
+    torch.set_default_dtype(torch.float32)
+
+    if library == Backend.TORCH_GPU:
+        import torch.cuda
+
+        if not torch.cuda.is_available():
+            pytest.skip("no CUDA device available")
+        torch.set_default_device("cuda")
+    else:  # TORCH
+        torch.set_default_device("cpu")
+
+
 @pytest.fixture(params=[Backend.DASK])  # Can select the test with `pytest -k dask`
 def da(
     request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
@@ -201,6 +210,15 @@ def jnp(
     return xp


+@pytest.fixture(params=[Backend.TORCH, Backend.TORCH_GPU])
+def torch(request: pytest.FixtureRequest) -> ModuleType:  # numpydoc ignore=PR01,RT01
+    """Variant of the `xp` fixture that only yields torch."""
+    xp = pytest.importorskip("torch")
+    xp = array_namespace(xp.empty(0))
+    _setup_torch(request.param)
+    return xp
+
+
 @pytest.fixture
 def device(
     library: Backend, xp: ModuleType
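For context, a hypothetical test using the new parametrized `torch` fixture (a sketch only; `test_default_device` is an invented name, not part of this diff). The fixture runs the body once per param, on CPU and, when available, on CUDA:

    def test_default_device(torch):
        # `torch` is the array-api-compat namespace returned by the fixture;
        # _setup_torch() has already set the default device for this param.
        x = torch.empty(0)
        assert str(x.device) in ("cpu", "cuda:0")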
34 changes: 34 additions & 0 deletions tests/test_funcs.py
@@ -17,6 +17,7 @@
     broadcast_shapes,
     cov,
     create_diagonal,
+    default_dtype,
     expand_dims,
     isclose,
     kron,
@@ -517,6 +518,39 @@ def test_xp(self, xp: ModuleType):
         xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]]))

+@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no __array_namespace_info__")
+class TestDefaultDType:
+    def test_basic(self, xp: ModuleType):
+        assert default_dtype(xp) == xp.empty(0).dtype
+
+    def test_kind(self, xp: ModuleType):
+        assert default_dtype(xp, "real floating") == xp.empty(0).dtype
+        assert default_dtype(xp, "complex floating") == (xp.empty(0) * 1j).dtype
+        assert default_dtype(xp, "integral") == xp.int64
+        assert default_dtype(xp, "indexing") == xp.int64
+
+        with pytest.raises(ValueError, match="Unknown kind"):
+            _ = default_dtype(xp, "foo")  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
+
+    def test_device(self, xp: ModuleType, device: Device):
+        # Note: at the moment there are no known namespaces with
+        # device-specific default dtypes.
+        assert default_dtype(xp, device=None) == xp.empty(0).dtype
+        assert default_dtype(xp, device=device) == xp.empty(0).dtype
+
+    def test_torch(self, torch: ModuleType):
+        xp = torch
+        xp.set_default_dtype(xp.float64)
+        assert default_dtype(xp) == xp.float64
+        assert default_dtype(xp, "real floating") == xp.float64
+        assert default_dtype(xp, "complex floating") == xp.complex128
+
+        xp.set_default_dtype(xp.float32)
+        assert default_dtype(xp) == xp.float32
+        assert default_dtype(xp, "real floating") == xp.float32
+        assert default_dtype(xp, "complex floating") == xp.complex64
+
+
 class TestExpandDims:
     def test_single_axis(self, xp: ModuleType):
         """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
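To run only the new tests locally, standard pytest selection works (illustrative command, mirroring the `pytest -k dask` hint in conftest.py):

    pytest tests/test_funcs.py -k TestDefaultDType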