1- from numpy .linalg import * # noqa: F403
2- from numpy .linalg import __all__ as linalg_all
3- import numpy as _np
1+ # pyright: reportAttributeAccessIssue=false
2+ # pyright: reportUnknownArgumentType=false
3+ # pyright: reportUnknownMemberType=false
4+ # pyright: reportUnknownVariableType=false
5+
6+ from __future__ import annotations
7+
8+ import numpy as np
9+
10+ # intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
11+ from numpy .linalg import (
12+ LinAlgError ,
13+ cond ,
14+ det ,
15+ eig ,
16+ eigvals ,
17+ eigvalsh ,
18+ inv ,
19+ lstsq ,
20+ matrix_power ,
21+ multi_dot ,
22+ norm ,
23+ tensorinv ,
24+ tensorsolve ,
25+ )
426
5- from ..common import _linalg
627from .._internal import get_xp
28+ from ..common import _linalg
729
830# These functions are in both the main and linalg namespaces
9- from ._aliases import matmul , matrix_transpose , tensordot , vecdot # noqa: F401
10-
11- import numpy as np
31+ from ._aliases import matmul , matrix_transpose , tensordot , vecdot # noqa: F401
32+ from ._typing import Array
1233
1334cross = get_xp (np )(_linalg .cross )
1435outer = get_xp (np )(_linalg .outer )
3859# To workaround this, the below is the code from np.linalg.solve except
3960# only calling solve1 in the exactly 1D case.
4061
62+
4163# This code is here instead of in common because it is numpy specific. Also
4264# note that CuPy's solve() does not currently support broadcasting (see
4365# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
44- def solve (x1 : _np . ndarray , x2 : _np . ndarray , / ) -> _np . ndarray :
66+ def solve (x1 : Array , x2 : Array , / ) -> Array :
4567 try :
4668 from numpy .linalg ._linalg import (
47- _makearray , _assert_stacked_2d , _assert_stacked_square ,
48- _commonType , isComplexType , _raise_linalgerror_singular
69+ _assert_stacked_2d ,
70+ _assert_stacked_square ,
71+ _commonType ,
72+ _makearray ,
73+ _raise_linalgerror_singular ,
74+ isComplexType ,
4975 )
5076 except ImportError :
5177 from numpy .linalg .linalg import (
52- _makearray , _assert_stacked_2d , _assert_stacked_square ,
53- _commonType , isComplexType , _raise_linalgerror_singular
78+ _assert_stacked_2d ,
79+ _assert_stacked_square ,
80+ _commonType ,
81+ _makearray ,
82+ _raise_linalgerror_singular ,
83+ isComplexType ,
5484 )
5585 from numpy .linalg import _umath_linalg
5686
@@ -61,30 +91,53 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
6191 t , result_t = _commonType (x1 , x2 )
6292
6393 # This part is different from np.linalg.solve
94+ gufunc : np .ufunc
6495 if x2 .ndim == 1 :
6596 gufunc = _umath_linalg .solve1
6697 else :
6798 gufunc = _umath_linalg .solve
6899
69100 # This does nothing currently but is left in because it will be relevant
70101 # when complex dtype support is added to the spec in 2022.
71- signature = 'DD->D' if isComplexType (t ) else 'dd->d'
72- with _np .errstate (call = _raise_linalgerror_singular , invalid = 'call' ,
73- over = 'ignore' , divide = 'ignore' , under = 'ignore' ):
74- r = gufunc (x1 , x2 , signature = signature )
102+ signature = "DD->D" if isComplexType (t ) else "dd->d"
103+ with np .errstate (
104+ call = _raise_linalgerror_singular ,
105+ invalid = "call" ,
106+ over = "ignore" ,
107+ divide = "ignore" ,
108+ under = "ignore" ,
109+ ):
110+ r : Array = gufunc (x1 , x2 , signature = signature )
75111
76112 return wrap (r .astype (result_t , copy = False ))
77113
114+
78115# These functions are completely new here. If the library already has them
79116# (i.e., numpy 2.0), use the library version instead of our wrapper.
80- if hasattr (np .linalg , ' vector_norm' ):
117+ if hasattr (np .linalg , " vector_norm" ):
81118 vector_norm = np .linalg .vector_norm
82119else :
83120 vector_norm = get_xp (np )(_linalg .vector_norm )
84121
85- __all__ = linalg_all + _linalg .__all__ + ['solve' ]
86122
87- del get_xp
88- del np
89- del linalg_all
90- del _linalg
123+ __all__ = [
124+ "LinAlgError" ,
125+ "cond" ,
126+ "det" ,
127+ "eig" ,
128+ "eigvals" ,
129+ "eigvalsh" ,
130+ "inv" ,
131+ "lstsq" ,
132+ "matrix_power" ,
133+ "multi_dot" ,
134+ "norm" ,
135+ "tensorinv" ,
136+ "tensorsolve" ,
137+ ]
138+ __all__ += _linalg .__all__
139+ __all__ += ["solve" , "vector_norm" ]
140+
141+
142+ def __dir__ () -> list [str ]:
143+ return __all__
0 commit comments