40
40
41
41
42
42
class RemoveSelfAnnotation (IParser ):
43
-
44
43
__any_t_name = QualifiedName .from_str ("Any" )
45
44
__typing_any_t_name = QualifiedName .from_str ("typing.Any" )
46
45
@@ -632,10 +631,19 @@ def report_error(self, error: ParserError) -> None:
632
631
633
632
634
633
class FixNumpyArrayDimTypeVar (IParser ):
635
- __array_names : set [QualifiedName ] = {QualifiedName .from_str ("numpy.ndarray" )}
634
+ __array_names : set [QualifiedName ] = {
635
+ QualifiedName .from_str ("numpy.ndarray" ),
636
+ QualifiedName .from_str ("numpy.typing.ArrayLike" ),
637
+ QualifiedName .from_str ("numpy.typing.NDArray" ),
638
+ }
639
+ __typing_annotated_names = {
640
+ QualifiedName .from_str ("typing.Annotated" ),
641
+ QualifiedName .from_str ("typing_extensions.Annotated" ),
642
+ }
636
643
numpy_primitive_types = FixNumpyArrayDimAnnotation .numpy_primitive_types
637
644
638
645
__DIM_VARS : set [str ] = set ()
646
+ __DIM_STRING_PATTERN = re .compile (r'"\[(.*?)\]"' )
639
647
640
648
def handle_module (
641
649
self , path : QualifiedName , module : types .ModuleType
@@ -659,85 +667,155 @@ def handle_module(
659
667
)
660
668
661
669
self .__DIM_VARS .clear ()
662
-
663
670
return result
664
671
665
672
def parse_annotation_str(
    self, annotation_str: str
) -> ResolvedType | InvalidExpression | Value:
    """Parse an annotation and normalise NumPy array annotations.

    The string is first parsed by the next parser in the chain. Results
    that are not a ``ResolvedType`` are returned untouched. Otherwise:

    * an unqualified, single-letter name is upper-cased, recorded as a
      dimension type variable and returned;
    * ``numpy.ndarray[...]`` is passed to the old-style handler;
    * other known array names and ``Annotated[...]`` are passed to the
      new-style handler.

    When a handler yields parameters, a fresh ``numpy.ndarray`` type with
    those parameters is returned; when it yields ``None`` the parsed
    result is returned unchanged.
    """
    parsed = super().parse_annotation_str(annotation_str)
    if not isinstance(parsed, ResolvedType):
        return parsed

    # handle unqualified, single-letter annotation as a TypeVar
    if len(parsed.name) == 1 and len(parsed.name[0]) == 1:
        parsed.name = QualifiedName.from_str(parsed.name[0].upper())
        self.__DIM_VARS.add(parsed.name[0])
        return parsed

    ndarray_name = QualifiedName.from_str("numpy.ndarray")
    if parsed.name == ndarray_name:
        new_parameters = self._handle_old_style_numpy_array(parsed.parameters)
    elif parsed.name in self.__array_names:
        # A bare array type is treated like the first item of Annotated[...].
        new_parameters = self._handle_new_style_numpy_array([parsed])
    elif parsed.name in self.__typing_annotated_names:
        new_parameters = self._handle_new_style_numpy_array(parsed.parameters)
    else:
        return parsed

    if new_parameters is None:  # The handler could not rewrite the type.
        return parsed
    return ResolvedType(
        name=QualifiedName.from_str("numpy.ndarray"), parameters=new_parameters
    )
698
+
699
def _process_numpy_array_type(
    self, scalar_type_name: QualifiedName, dimensions: list[int | str] | None
) -> tuple[ResolvedType, ResolvedType]:
    """Build the ``(shape, dtype)`` type arguments for ``numpy.ndarray``.

    Args:
        scalar_type_name: Name of the array's scalar type.
        dimensions: Dimension sizes; integers become ``Literal[...]``
            entries, strings become upper-cased type names. ``None`` or an
            empty list leaves the shape as ``Any``.

    Returns:
        The shape type followed by ``numpy.dtype[<scalar type>]``.
    """
    # Pybind annotates a bool Python type, which cannot be used with
    # numpy.dtype because it does not inherit from numpy.generic.
    # Only numpy.bool_ works reliably with both NumPy 1.x and 2.x.
    if str(scalar_type_name) == "bool":
        scalar_type_name = QualifiedName.from_str("numpy.bool_")
    scalar = ResolvedType(name=scalar_type_name)
    dtype = ResolvedType(
        name=QualifiedName.from_str("numpy.dtype"), parameters=[scalar]
    )

    # "Any" is always parsed first, even when it is replaced below, so the
    # sequence of parse_annotation_str calls stays the same.
    shape = self.parse_annotation_str("Any")
    if not dimensions:
        return shape, dtype

    shape = self.parse_annotation_str("Tuple")
    assert isinstance(shape, ResolvedType)
    shape.parameters = []
    for axis in dimensions:
        if not isinstance(axis, int):
            # Symbolic dimension: refer to it by its upper-cased name.
            axis_name = QualifiedName.from_str(axis.upper())
            shape.parameters.append(ResolvedType(name=axis_name))
            continue
        # Fixed-size dimension: Literal[<size>].
        literal_axis = self.parse_annotation_str("Literal")
        assert isinstance(literal_axis, ResolvedType)
        literal_axis.parameters = [Value(repr=str(axis))]
        shape.parameters.append(literal_axis)
    return shape, dtype
728
+
729
def _handle_new_style_numpy_array(
    self, parameters: list[ResolvedType | Value | InvalidExpression] | None
) -> list[ResolvedType] | None:
    """Translate a new-style array annotation into ndarray type arguments.

    Recognised parameter layouts (the contents of ``Annotated[...]``, or a
    single bare array type):

        numpy.typing.ArrayLike, numpy.float32, "[m, n]"
        numpy.typing.NDArray[numpy.float32], "[m, n]"
        numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", ...
        numpy.ndarray[<shape>, numpy.dtype[numpy.float32]], "[m, n]"

    Args:
        parameters: The parsed parameters, array type first.

    Returns:
        ``[shape, dtype]`` on success, or ``None`` when the parameters do
        not match a recognised layout (the caller then keeps the original
        annotation). Malformed input never raises.
    """
    if not parameters:
        return None

    array_type, *extras = parameters
    if (
        not isinstance(array_type, ResolvedType)
        or array_type.name not in self.__array_names
    ):
        return None

    scalar_type: ResolvedType | Value | InvalidExpression
    dims_and_flags: Sequence[ResolvedType | Value | InvalidExpression]
    if array_type.name == QualifiedName.from_str("numpy.typing.ArrayLike"):
        # The scalar type is the first extra argument of Annotated[...].
        if not extras:
            return None
        scalar_type, *dims_and_flags = extras
    elif array_type.name == QualifiedName.from_str("numpy.typing.NDArray"):
        # NDArray carries exactly one type argument: the scalar type.
        if array_type.parameters is None or len(array_type.parameters) != 1:
            return None
        [scalar_type] = array_type.parameters
        dims_and_flags = extras
    elif array_type.name == QualifiedName.from_str("numpy.ndarray"):
        # ndarray[<shape>, numpy.dtype[<scalar>]]
        if array_type.parameters is None or len(array_type.parameters) != 2:
            return None
        _, dtype_param = array_type.parameters
        if not (
            isinstance(dtype_param, ResolvedType)
            and dtype_param.name == QualifiedName.from_str("numpy.dtype")
            and dtype_param.parameters is not None
            and len(dtype_param.parameters) == 1
        ):
            return None
        [scalar_type] = dtype_param.parameters
        dims_and_flags = extras
    else:
        return None

    # The scalar slot may hold a Value or InvalidExpression (for example
    # when the dims string directly follows ArrayLike); only a resolved
    # primitive type can be turned into a dtype.
    if (
        not isinstance(scalar_type, ResolvedType)
        or scalar_type.name not in self.numpy_primitive_types
    ):
        return None

    dims: list[int | str] | None = None
    if dims_and_flags:
        # Only the first extra argument can describe the shape; any
        # remaining "flags.*" entries are ignored.
        dims_value = dims_and_flags[0]
        if isinstance(dims_value, Value):
            match = self.__DIM_STRING_PATTERN.search(dims_value.repr)
            if match:
                dim_names = [
                    part.strip()
                    for part in match.group(1).split(",")
                    if part.strip()
                ]
                if dim_names:
                    dims = self.__to_dims_from_strings(dim_names)

    # Return a list, as declared, rather than the helper's tuple.
    return list(self._process_numpy_array_type(scalar_type.name, dims))
786
+
787
def _handle_old_style_numpy_array(
    self, parameters: list[ResolvedType | Value | InvalidExpression] | None
) -> list[ResolvedType] | None:
    """Translate old-style ``numpy.ndarray`` parameters.

    Affects types of the following pattern:
        numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
    Replace with:
        numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]

    Args:
        parameters: The type arguments of the parsed ``numpy.ndarray``.

    Returns:
        ``[shape, dtype]``; ``[Any, numpy.dtype[Any]]`` when no parameters
        are given; ``None`` when the first parameter is not a known NumPy
        primitive type (the caller then keeps the original annotation).
    """
    # ndarray is generic and should have 2 type arguments
    if not parameters:
        return [
            self.parse_annotation_str("Any"),
            ResolvedType(
                name=QualifiedName.from_str("numpy.dtype"),
                parameters=[self.parse_annotation_str("Any")],
            ),
        ]

    scalar_with_dims = parameters[0]  # e.g. numpy.float64[32, 32]
    if (
        not isinstance(scalar_with_dims, ResolvedType)
        or scalar_with_dims.name not in self.numpy_primitive_types
    ):
        return None

    dims: list[int | str] | None = None
    if scalar_with_dims.parameters:
        dims = self.__to_dims(scalar_with_dims.parameters)

    # Return a list on every path, as declared, rather than the helper's
    # tuple.
    return list(self._process_numpy_array_type(scalar_with_dims.name, dims))
741
819
742
820
def __to_dims (
743
821
self , dimensions : Sequence [ResolvedType | Value | InvalidExpression ]
@@ -756,6 +834,20 @@ def __to_dims(
756
834
result .append (dim )
757
835
return result
758
836
837
def __to_dims_from_strings(
    self, dimensions: Sequence[str]
) -> list[int | str] | None:
    """Convert textual dimension entries into ints or symbolic names.

    Entries that ``int()`` accepts become integers. Any other entry is
    kept as a string; a single-character entry is also recorded,
    upper-cased, in the set of dimension type variables.
    """
    converted: list[int | str] = []
    for text in dimensions:
        entry: int | str
        try:
            entry = int(text)
        except ValueError:
            entry = text
            # Assuming single letter dims are type vars
            if len(text) == 1:
                self.__DIM_VARS.add(text.upper())  # Add uppercase to TypeVar set
        converted.append(entry)
    return converted
850
+
759
851
def report_error (self , error : ParserError ) -> None :
760
852
if (
761
853
isinstance (error , NameResolutionError )
0 commit comments