Skip to content

Commit 6c42ca3

Browse files
committed
Fix FixNumpyArrayDimTypeVar for pybind v3.0.0
1 parent 0a566ba commit 6c42ca3

File tree

19 files changed

+247
-116
lines changed

19 files changed

+247
-116
lines changed

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 139 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040

4141

4242
class RemoveSelfAnnotation(IParser):
43-
4443
__any_t_name = QualifiedName.from_str("Any")
4544
__typing_any_t_name = QualifiedName.from_str("typing.Any")
4645

@@ -632,10 +631,19 @@ def report_error(self, error: ParserError) -> None:
632631

633632

634633
class FixNumpyArrayDimTypeVar(IParser):
635-
__array_names: set[QualifiedName] = {QualifiedName.from_str("numpy.ndarray")}
634+
__array_names: set[QualifiedName] = {
635+
QualifiedName.from_str("numpy.ndarray"),
636+
QualifiedName.from_str("numpy.typing.ArrayLike"),
637+
QualifiedName.from_str("numpy.typing.NDArray"),
638+
}
639+
__typing_annotated_names = {
640+
QualifiedName.from_str("typing.Annotated"),
641+
QualifiedName.from_str("typing_extensions.Annotated"),
642+
}
636643
numpy_primitive_types = FixNumpyArrayDimAnnotation.numpy_primitive_types
637644

638645
__DIM_VARS: set[str] = set()
646+
__DIM_STRING_PATTERN = re.compile(r'"\[(.*?)\]"')
639647

640648
def handle_module(
641649
self, path: QualifiedName, module: types.ModuleType
@@ -659,85 +667,155 @@ def handle_module(
659667
)
660668

661669
self.__DIM_VARS.clear()
662-
663670
return result
664671

665672
def parse_annotation_str(
666673
self, annotation_str: str
667674
) -> ResolvedType | InvalidExpression | Value:
668-
# Affects types of the following pattern:
669-
# numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
670-
# Replace with:
671-
# numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]
672-
673675
result = super().parse_annotation_str(annotation_str)
674-
675676
if not isinstance(result, ResolvedType):
676677
return result
677678

678679
# handle unqualified, single-letter annotation as a TypeVar
679680
if len(result.name) == 1 and len(result.name[0]) == 1:
680681
result.name = QualifiedName.from_str(result.name[0].upper())
681682
self.__DIM_VARS.add(result.name[0])
683+
return result
682684

683-
if result.name not in self.__array_names:
685+
if result.name == QualifiedName.from_str("numpy.ndarray"):
686+
parameters = self._handle_old_style_numpy_array(result.parameters)
687+
elif result.name in self.__array_names:
688+
parameters = self._handle_new_style_numpy_array([result])
689+
elif result.name in self.__typing_annotated_names:
690+
parameters = self._handle_new_style_numpy_array(result.parameters)
691+
else:
692+
parameters = None
693+
if parameters is None: # Failure.
684694
return result
695+
return ResolvedType(
696+
name=QualifiedName.from_str("numpy.ndarray"), parameters=parameters
697+
)
698+
699+
def _process_numpy_array_type(
700+
self, scalar_type_name: QualifiedName, dimensions: list[int | str] | None
701+
) -> tuple[ResolvedType, ResolvedType]:
702+
# Pybind annotates a bool Python type, which cannot be used with
703+
# numpy.dtype because it does not inherit from numpy.generic.
704+
# Only numpy.bool_ works reliably with both NumPy 1.x and 2.x.
705+
if str(scalar_type_name) == "bool":
706+
scalar_type_name = QualifiedName.from_str("numpy.bool_")
707+
dtype = ResolvedType(
708+
name=QualifiedName.from_str("numpy.dtype"),
709+
parameters=[ResolvedType(name=scalar_type_name)],
710+
)
711+
712+
shape = self.parse_annotation_str("Any")
713+
if dimensions is not None and len(dimensions) > 0:
714+
shape = self.parse_annotation_str("Tuple")
715+
assert isinstance(shape, ResolvedType)
716+
shape.parameters = []
717+
for dim in dimensions:
718+
if isinstance(dim, int):
719+
literal_dim = self.parse_annotation_str("Literal")
720+
assert isinstance(literal_dim, ResolvedType)
721+
literal_dim.parameters = [Value(repr=str(dim))]
722+
shape.parameters.append(literal_dim)
723+
else:
724+
shape.parameters.append(
725+
ResolvedType(name=QualifiedName.from_str(dim.upper()))
726+
)
727+
return shape, dtype
728+
729+
def _handle_new_style_numpy_array(
730+
self, parameters: list[ResolvedType | Value | InvalidExpression] | None
731+
) -> list[ResolvedType] | None:
732+
# Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, n]"]
733+
# Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]"]
734+
# Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"]
735+
if parameters is None or len(parameters) == 0:
736+
return
737+
738+
array_type, *parameters = parameters
739+
if (
740+
not isinstance(array_type, ResolvedType)
741+
or array_type.name not in self.__array_names
742+
):
743+
return
744+
745+
dims_and_flags: Sequence[ResolvedType | Value | InvalidExpression]
746+
if array_type.name == QualifiedName.from_str("numpy.typing.ArrayLike"):
747+
if not parameters:
748+
return
749+
scalar_type, *dims_and_flags = parameters
750+
elif array_type.name == QualifiedName.from_str("numpy.typing.NDArray"):
751+
if array_type.parameters is None or len(array_type.parameters) == 0:
752+
return
753+
[scalar_type] = array_type.parameters
754+
dims_and_flags = parameters
755+
elif array_type.name == QualifiedName.from_str("numpy.ndarray"):
756+
_, dtype_param = array_type.parameters
757+
if not (
758+
isinstance(dtype_param, ResolvedType)
759+
and dtype_param.name == QualifiedName.from_str("numpy.dtype")
760+
and dtype_param.parameters
761+
):
762+
return
763+
[scalar_type] = dtype_param.parameters
764+
dims_and_flags = parameters
765+
else:
766+
return
767+
scalar_type_name = scalar_type.name
768+
if scalar_type_name not in self.numpy_primitive_types:
769+
return
770+
771+
dims: list[int | str] | None = None
772+
if dims_and_flags:
773+
dims_str, *flags = dims_and_flags
774+
del flags # Unused.
775+
if isinstance(dims_str, Value):
776+
match = self.__DIM_STRING_PATTERN.search(dims_str.repr)
777+
if match:
778+
dims_str_content = match.group(1)
779+
dims_list = [
780+
d.strip() for d in dims_str_content.split(",") if d.strip()
781+
]
782+
if dims_list:
783+
dims = self.__to_dims_from_strings(dims_list)
784+
785+
return self._process_numpy_array_type(scalar_type_name, dims)
786+
787+
def _handle_old_style_numpy_array(
788+
self, parameters: list[ResolvedType | Value | InvalidExpression] | None
789+
) -> list[ResolvedType] | None:
790+
# Affects types of the following pattern:
791+
# numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
792+
# Replace with:
793+
# numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]
685794

686795
# ndarray is generic and should have 2 type arguments
687-
if result.parameters is None or len(result.parameters) == 0:
688-
result.parameters = [
796+
if parameters is None or len(parameters) == 0:
797+
return [
689798
self.parse_annotation_str("Any"),
690799
ResolvedType(
691800
name=QualifiedName.from_str("numpy.dtype"),
692801
parameters=[self.parse_annotation_str("Any")],
693802
),
694803
]
695-
return result
696-
697-
scalar_with_dims = result.parameters[0] # e.g. numpy.float64[32, 32]
698804

805+
scalar_with_dims = parameters[0] # e.g. numpy.float64[32, 32]
699806
if (
700807
not isinstance(scalar_with_dims, ResolvedType)
701808
or scalar_with_dims.name not in self.numpy_primitive_types
702809
):
703-
return result
704-
705-
name = scalar_with_dims.name
706-
# Pybind annotates a bool Python type, which cannot be used with
707-
# numpy.dtype because it does not inherit from numpy.generic.
708-
# Only numpy.bool_ works reliably with both NumPy 1.x and 2.x.
709-
if str(name) == "bool":
710-
name = QualifiedName.from_str("numpy.bool_")
711-
dtype = ResolvedType(
712-
name=QualifiedName.from_str("numpy.dtype"),
713-
parameters=[ResolvedType(name=name)],
714-
)
810+
return
715811

716-
shape = self.parse_annotation_str("Any")
812+
dims: list[int | str] | None = None
717813
if (
718814
scalar_with_dims.parameters is not None
719815
and len(scalar_with_dims.parameters) > 0
720816
):
721817
dims = self.__to_dims(scalar_with_dims.parameters)
722-
if dims is not None:
723-
shape = self.parse_annotation_str("Tuple")
724-
assert isinstance(shape, ResolvedType)
725-
shape.parameters = []
726-
for dim in dims:
727-
if isinstance(dim, int):
728-
# self.parse_annotation_str will qualify Literal with either
729-
# typing or typing_extensions and add the import to the module
730-
literal_dim = self.parse_annotation_str("Literal")
731-
assert isinstance(literal_dim, ResolvedType)
732-
literal_dim.parameters = [Value(repr=str(dim))]
733-
shape.parameters.append(literal_dim)
734-
else:
735-
shape.parameters.append(
736-
ResolvedType(name=QualifiedName.from_str(dim))
737-
)
738-
739-
result.parameters = [shape, dtype]
740-
return result
818+
return self._process_numpy_array_type(scalar_with_dims.name, dims)
741819

742820
def __to_dims(
743821
self, dimensions: Sequence[ResolvedType | Value | InvalidExpression]
@@ -756,6 +834,20 @@ def __to_dims(
756834
result.append(dim)
757835
return result
758836

837+
def __to_dims_from_strings(
838+
self, dimensions: Sequence[str]
839+
) -> list[int | str] | None:
840+
result: list[int | str] = []
841+
for dim_str in dimensions:
842+
try:
843+
dim = int(dim_str)
844+
except ValueError:
845+
dim = dim_str
846+
if len(dim) == 1: # Assuming single letter dims are type vars
847+
self.__DIM_VARS.add(dim.upper()) # Add uppercase to TypeVar set
848+
result.append(dim)
849+
return result
850+
759851
def report_error(self, error: ParserError) -> None:
760852
if (
761853
isinstance(error, NameResolutionError)

tests/check-demo-stubs-generation.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
set -e
3+
set -ex
44

55
function parse_args() {
66

@@ -30,8 +30,8 @@ function parse_args() {
3030
if [ -z "$STUBS_SUB_DIR" ]; then usage "STUBS_SUB_DIR is not set"; fi;
3131
if [ -z "$NUMPY_FORMAT" ]; then usage "NUMPY_FORMAT is not set"; fi;
3232

33-
TESTS_ROOT="$(readlink -m "$(dirname "$0")")"
34-
STUBS_DIR=$(readlink -m "${TESTS_ROOT}/${STUBS_SUB_DIR}")
33+
TESTS_ROOT="$(greadlink -m "$(dirname "$0")")"
34+
STUBS_DIR=$(greadlink -m "${TESTS_ROOT}/${STUBS_SUB_DIR}")
3535
}
3636

3737
remove_stubs() {

tests/demo-lib/include/demo/Foo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace demo{
55

66

77
class CppException : public std::runtime_error {
8-
using std::runtime_error::runtime_error;
8+
//using std::runtime_error;
99
};
1010

1111
struct Foo {

tests/install-demo-module.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function parse_args() {
2727
# verify params
2828
if [ -z "$PYBIND11_BRANCH" ]; then usage "PYBIND11_BRANCH is not set"; fi;
2929

30-
TESTS_ROOT="$(readlink -m "$(dirname "$0")")"
30+
TESTS_ROOT="$(greadlink -m "$(dirname "$0")")"
3131
PROJECT_ROOT="${TESTS_ROOT}/.."
3232
TEMP_DIR="${PROJECT_ROOT}/tmp/pybind11-${PYBIND11_BRANCH}"
3333
INSTALL_PREFIX="${TEMP_DIR}/install"
@@ -67,7 +67,7 @@ install_demo() {
6767

6868
install_pydemo() {
6969
(
70-
export CMAKE_PREFIX_PATH="$(readlink -m "${INSTALL_PREFIX}"):$(cmeel cmake)";
70+
export CMAKE_PREFIX_PATH="$(greadlink -m "${INSTALL_PREFIX}"):$(cmeel cmake)";
7171
export CMAKE_ARGS="-DCMAKE_CXX_STANDARD=17";
7272
pip install --force-reinstall "${TESTS_ROOT}/py-demo"
7373
)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "modules.h"
22

3-
namespace {
3+
namespace mymodules {
44
struct Dummy {
55
int regular_method(int x) { return x + 1; }
66
static int static_method(int x) { return x + 1; }
@@ -9,8 +9,8 @@ struct Dummy {
99
} // namespace
1010

1111
void bind_methods_module(py::module&& m) {
12-
auto &&pyDummy = py::class_<Dummy>(m, "Dummy");
12+
auto &&pyDummy = py::class_<mymodules::Dummy>(m, "Dummy");
1313

14-
pyDummy.def_static("static_method", &Dummy::static_method);
15-
pyDummy.def("regular_method", &Dummy::regular_method);
14+
pyDummy.def_static("static_method", &mymodules::Dummy::static_method);
15+
pyDummy.def("regular_method", &mymodules::Dummy::regular_method);
1616
}

tests/py-demo/bindings/src/modules/values.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44

55
#include <chrono>
66

7-
namespace {
7+
namespace myvalues {
88
class Dummy {};
99
class Foo {};
1010
} // namespace
1111

1212
void bind_values_module(py::module &&m) {
1313
{
1414
// python module as value
15-
auto &&pyDummy = py::class_<Dummy>(m, "Dummy");
15+
auto &&pyDummy = py::class_<myvalues::Dummy>(m, "Dummy");
1616

1717
pyDummy.def_property_readonly_static(
1818
"linalg", [](py::object &) { return py::module::import("numpy.linalg"); });
@@ -27,12 +27,12 @@ void bind_values_module(py::module &&m) {
2727
m.attr("list_with_none") = li;
2828
}
2929
{
30-
auto pyFoo = py::class_<Foo>(m, "Foo");
31-
m.attr("foovar") = Foo();
30+
auto pyFoo = py::class_<myvalues::Foo>(m, "Foo");
31+
m.attr("foovar") = myvalues::Foo();
3232

3333
py::list foolist;
34-
foolist.append(Foo());
35-
foolist.append(Foo());
34+
foolist.append(myvalues::Foo());
35+
foolist.append(myvalues::Foo());
3636

3737
m.attr("foolist") = foolist;
3838
m.attr("none") = py::none();

tests/stubs/python-3.12/pybind11-v3.0.0/numpy-array-use-type-var/demo/_bindings/aliases/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class Dummy:
4747
def foreign_enum_default(
4848
color: typing.Any = demo._bindings.enum.ConsoleForegroundColor.Blue,
4949
) -> None: ...
50-
def func(arg0: int) -> int: ...
50+
def func(arg0: typing.SupportsInt) -> int: ...
5151

5252
local_func_alias = func
5353
local_type_alias = Color

tests/stubs/python-3.12/pybind11-v3.0.0/numpy-array-use-type-var/demo/_bindings/classes.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ class CppException(Exception):
1313
pass
1414

1515
class Derived(Base):
16-
count: int
16+
@property
17+
def count(self) -> int: ...
18+
@count.setter
19+
def count(self, arg0: typing.SupportsInt) -> None: ...
1720

1821
class Foo:
1922
class FooChild:
@@ -43,11 +46,11 @@ class Outer:
4346
def __getstate__(self) -> int: ...
4447
def __hash__(self) -> int: ...
4548
def __index__(self) -> int: ...
46-
def __init__(self, value: int) -> None: ...
49+
def __init__(self, value: typing.SupportsInt) -> None: ...
4750
def __int__(self) -> int: ...
4851
def __ne__(self, other: typing.Any) -> bool: ...
4952
def __repr__(self) -> str: ...
50-
def __setstate__(self, state: int) -> None: ...
53+
def __setstate__(self, state: typing.SupportsInt) -> None: ...
5154
def __str__(self) -> str: ...
5255
@property
5356
def name(self) -> str: ...

tests/stubs/python-3.12/pybind11-v3.0.0/numpy-array-use-type-var/demo/_bindings/eigen.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from __future__ import annotations
33
import typing
44

55
import numpy
6+
import numpy.typing
67
import scipy.sparse
78

89
__all__: list[str] = [

0 commit comments

Comments
 (0)