From a60869af172d8575ae31a8bff6653b7de26bec79 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Nov 2025 16:33:12 +0100 Subject: [PATCH 1/3] ENH: torch: allow negative indices in take_along_axis --- array_api_compat/torch/_aliases.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index d3857755..2903ac3e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -819,7 +819,11 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: - return torch.take_along_dim(x, indices, dim=axis) + return torch.take_along_dim( + x, + torch.where(indices < 0, indices + x.shape[axis], indices), + dim=axis + ) def sign(x: Array, /) -> Array: From 92662a6f41e6c7c5d168a5c231e66c7a63c56674 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Nov 2025 17:23:05 +0100 Subject: [PATCH 2/3] ENH: torch: allow negative indices in take() --- array_api_compat/torch/_aliases.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 2903ac3e..7fc1194e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -815,7 +815,12 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") axis = 0 - return torch.index_select(x, axis, indices, **kwargs) + return torch.index_select( + x, + axis, + torch.where(indices < 0, indices + x.shape[axis], indices), + **kwargs + ) def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: From 4355ab819c3c24c15861e3891bcdd58899f12421 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Nov 2025 10:35:35 +0100 Subject: [PATCH 3/3] MAINT: link to pytorch issue for negative indices --- array_api_compat/torch/_aliases.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 7fc1194e..4e8533f9 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -815,6 +815,8 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") axis = 0 + # torch does not support negative indices, + # see https://github.com/pytorch/pytorch/issues/146211 return torch.index_select( x, axis, @@ -824,6 +826,8 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + # torch does not support negative indices, + # see https://github.com/pytorch/pytorch/issues/146211 return torch.take_along_dim( x, torch.where(indices < 0, indices + x.shape[axis], indices),