diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index d3857755..4e8533f9 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -815,11 +815,24 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") axis = 0 - return torch.index_select(x, axis, indices, **kwargs) + # torch does not support negative indices, + # see https://github.com/pytorch/pytorch/issues/146211 + return torch.index_select( + x, + axis, + torch.where(indices < 0, indices + x.shape[axis], indices), + **kwargs + ) def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: - return torch.take_along_dim(x, indices, dim=axis) + # torch does not support negative indices, + # see https://github.com/pytorch/pytorch/issues/146211 + return torch.take_along_dim( + x, + torch.where(indices < 0, indices + x.shape[axis], indices), + dim=axis + ) def sign(x: Array, /) -> Array: