diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6ae8663406b..ff54b2e6d15 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 * Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534)
 * Added implementation of `dpnp.linalg.lu_factor` (SciPy-compatible) [#2557](https://github.com/IntelPython/dpnp/pull/2557), [#2565](https://github.com/IntelPython/dpnp/pull/2565)
 * Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550)
+* Added implementation of `dpnp.linalg.lu_solve` for 2D inputs (SciPy-compatible) [#2575](https://github.com/IntelPython/dpnp/pull/2575)
 
 ### Changed
 
diff --git a/doc/reference/linalg.rst b/doc/reference/linalg.rst
index 107b5a86a5b..142c6052db8 100644
--- a/doc/reference/linalg.rst
+++ b/doc/reference/linalg.rst
@@ -86,6 +86,7 @@ Solving linear equations
    dpnp.linalg.solve
    dpnp.linalg.tensorsolve
    dpnp.linalg.lstsq
+   dpnp.linalg.lu_solve
    dpnp.linalg.inv
    dpnp.linalg.pinv
    dpnp.linalg.tensorinv
diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py
index 1a46205ea82..f73443229c4 100644
--- a/dpnp/linalg/dpnp_iface_linalg.py
+++ b/dpnp/linalg/dpnp_iface_linalg.py
@@ -57,6 +57,7 @@
     dpnp_inv,
     dpnp_lstsq,
     dpnp_lu_factor,
+    dpnp_lu_solve,
     dpnp_matrix_power,
     dpnp_matrix_rank,
     dpnp_multi_dot,
@@ -81,6 +82,7 @@
     "inv",
     "lstsq",
     "lu_factor",
+    "lu_solve",
     "matmul",
     "matrix_norm",
     "matrix_power",
@@ -905,7 +907,7 @@ def lstsq(a, b, rcond=None):
 
 def lu_factor(a, overwrite_a=False, check_finite=True):
     """
-    Compute the pivoted LU decomposition of a matrix.
+    Compute the pivoted LU decomposition of the matrix `a`.
 
     The decomposition is::
 
@@ -947,6 +949,11 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
     This function synchronizes in order to validate array elements
     when ``check_finite=True``.
 
+    See Also
+    --------
+    :obj:`dpnp.linalg.lu_solve` : Solve an equation system using
+        the LU factorization of a matrix.
+
     Examples
     --------
     >>> import dpnp as np
@@ -966,6 +973,81 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
     return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)
 
 
+def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
+    """
+    Solve a linear system, :math:`a x = b`, given the LU factorization of `a`.
+
+    For full documentation refer to :obj:`scipy.linalg.lu_solve`.
+
+    Parameters
+    ----------
+    lu_and_piv : tuple of {dpnp.ndarray, usm_ndarray}
+        LU factorization of the matrix `a` (M, M) together with pivot
+        indices, as returned by :obj:`dpnp.linalg.lu_factor`.
+    b : {(M,), (M, K)} {dpnp.ndarray, usm_ndarray}
+        Right-hand side.
+    trans : {0, 1, 2}, optional
+        Type of system to solve:
+
+        ===== =================
+        trans system
+        ===== =================
+        0     :math:`a x = b`
+        1     :math:`a^T x = b`
+        2     :math:`a^H x = b`
+        ===== =================
+
+        Default: ``0``.
+    overwrite_b : bool, optional
+        Whether to overwrite data in `b` (may increase performance).
+
+        Default: ``False``.
+    check_finite : bool, optional
+        Whether to check that the input matrix contains only finite numbers.
+        Disabling may give a performance gain, but may result in problems
+        (crashes, non-termination) if the inputs do contain infinities or NaNs.
+
+        Default: ``True``.
+
+    Returns
+    -------
+    x : {(M,), (M, K)} dpnp.ndarray
+        Solution to the system.
+
+    Warning
+    -------
+    This function synchronizes in order to validate array elements
+    when ``check_finite=True``.
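+
+    Notes
+    -----
+    The pivot indices `piv` follow the 0-based SciPy convention produced
+    by :obj:`dpnp.linalg.lu_factor`; they are converted internally to the
+    1-based convention expected by LAPACK ``getrs``.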
+
+    See Also
+    --------
+    :obj:`dpnp.linalg.lu_factor` : Compute the pivoted LU decomposition
+        of a matrix.
+
+    Examples
+    --------
+    >>> import dpnp as np
+    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
+    >>> b = np.array([1, 1, 1, 1])
+    >>> lu, piv = np.linalg.lu_factor(A)
+    >>> x = np.linalg.lu_solve((lu, piv), b)
+    >>> np.allclose(A @ x - b, np.zeros((4,)))
+    array(True)
+
+    """
+
+    lu, piv = lu_and_piv
+    dpnp.check_supported_arrays_type(lu, piv, b)
+    assert_stacked_2d(lu)
+
+    return dpnp_lu_solve(
+        lu,
+        piv,
+        b,
+        trans=trans,
+        overwrite_b=overwrite_b,
+        check_finite=check_finite,
+    )
+
+
 def matmul(x1, x2, /):
     """
     Computes the matrix product.
diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py
index bb2920e3a99..44a3816cc16 100644
--- a/dpnp/linalg/dpnp_utils_linalg.py
+++ b/dpnp/linalg/dpnp_utils_linalg.py
@@ -2477,6 +2477,121 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
     return (a_h, ipiv_h)
 
 
+def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
+    """
+    dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)
+
+    Solve an equation system (SciPy-compatible behavior).
+
+    This function mimics the behavior of `scipy.linalg.lu_solve`,
+    including support for `trans`, `overwrite_b`, `check_finite`,
+    and 0-based pivot indexing.
+
+    """
+
+    res_usm_type, exec_q = get_usm_allocations([lu, piv, b])
+
+    res_type = _common_type(lu, b)
+
+    # TODO: add broadcasting
+    if lu.shape[0] != b.shape[0]:
+        raise ValueError(
+            f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
+        )
+
+    if b.size == 0:
+        return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)
+
+    if lu.ndim > 2:
+        raise NotImplementedError("Batched matrices are not supported")
+
+    if check_finite:
+        if not dpnp.isfinite(lu).all():
+            raise ValueError(
+                "LU factorization array must not contain infs or NaNs.\n"
+                "Note that when a singular matrix is given, unlike "
+                "scipy.linalg.lu_factor, dpnp.linalg.lu_factor returns "
+                "an array containing NaNs."
+            )
+        if not dpnp.isfinite(b).all():
+            raise ValueError(
+                "Right-hand side array must not contain infs or NaNs"
+            )
+
+    lu_usm_arr = dpnp.get_usm_ndarray(lu)
+    b_usm_arr = dpnp.get_usm_ndarray(b)
+
+    # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
+    # convert to 1-based for oneMKL getrs
+    piv_h = piv + 1
+
+    _manager = dpu.SequentialOrderManager[exec_q]
+    dep_evs = _manager.submitted_events
+
+    # oneMKL LAPACK getrs overwrites `lu`.
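+    # A fresh F-ordered working copy in the result dtype is therefore
+    # allocated below, so the caller's factorization stays intact and
+    # can be reused for further solves.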
+    lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type)
+
+    # use DPCTL tensor function to fill the copy of the input array
+    # from the input array
+    ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+        src=lu_usm_arr,
+        dst=lu_h.get_array(),
+        sycl_queue=lu.sycl_queue,
+        depends=dep_evs,
+    )
+    _manager.add_event_pair(ht_ev, lu_copy_ev)
+
+    # SciPy-compatible behavior
+    # Copy is required if:
+    # - overwrite_b is False (always copy),
+    # - dtype mismatch,
+    # - not F-contiguous,
+    # - not writeable
+    if not overwrite_b or _is_copy_required(b, res_type):
+        b_h = dpnp.empty_like(
+            b, order="F", dtype=res_type, usm_type=res_usm_type
+        )
+        ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+            src=b_usm_arr,
+            dst=b_h.get_array(),
+            sycl_queue=b.sycl_queue,
+            depends=dep_evs,
+        )
+        _manager.add_event_pair(ht_ev, b_copy_ev)
+        dep_evs = [lu_copy_ev, b_copy_ev]
+    else:
+        # input is suitable for in-place modification
+        b_h = b
+        dep_evs = [lu_copy_ev]
+
+    if not isinstance(trans, int):
+        raise TypeError("`trans` must be an integer")
+
+    # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
+    if trans == 0:
+        trans_mkl = li.Transpose.N
+    elif trans == 1:
+        trans_mkl = li.Transpose.T
+    elif trans == 2:
+        trans_mkl = li.Transpose.C
+    else:
+        raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")
+
+    # Call the LAPACK extension function _getrs
+    # to solve the system of linear equations with an LU-factored
+    # coefficient square matrix, with multiple right-hand sides.
+    ht_ev, getrs_ev = li._getrs(
+        exec_q,
+        lu_h.get_array(),
+        piv_h.get_array(),
+        b_h.get_array(),
+        trans_mkl,
+        depends=dep_evs,
+    )
+    _manager.add_event_pair(ht_ev, getrs_ev)
+
+    return b_h
+
+
 def dpnp_matrix_power(a, n):
     """
     dpnp_matrix_power(a, n)
diff --git a/dpnp/tests/helper.py b/dpnp/tests/helper.py
index edb077a161c..93146159b11 100644
--- a/dpnp/tests/helper.py
+++ b/dpnp/tests/helper.py
@@ -1,3 +1,4 @@
+import importlib.util
 from sys import platform
 
 import dpctl
@@ -488,6 +489,14 @@ def is_ptl(device=None):
     return _get_dev_mask(device) in (0xB000, 0xFD00)
 
 
+def is_scipy_available():
+    """
+    Return True if SciPy is installed and can be found,
+    False otherwise.
+ """ + return importlib.util.find_spec("scipy") is not None + + def is_tgllp_iris_xe(device=None): """ Return True if a test is running on Tiger Lake-LP with Iris Xe GPU device, diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 97379876c90..a25c237f846 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -1931,6 +1931,7 @@ def test_overwrite_inplace(self, dtype): ) assert lu is a_dp + assert lu.data.ptr == a_dp.data.ptr assert lu.flags["F_CONTIGUOUS"] is True L, U = self._split_lu(lu, 2, 2) @@ -1948,6 +1949,7 @@ def test_overwrite_copy(self, dtype): ) assert lu is not a_dp + assert lu.data.ptr != a_dp.data.ptr assert lu.flags["F_CONTIGUOUS"] is True L, U = self._split_lu(lu, 2, 2) @@ -1974,6 +1976,7 @@ def test_overwrite_copy_special(self): ) assert lu is not a_dp + assert lu.data.ptr != a_dp.data.ptr assert lu.flags["F_CONTIGUOUS"] is True L, U = self._split_lu(lu, 2, 2) @@ -2144,6 +2147,217 @@ def test_check_finite_raises(self): assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True) +class TestLuSolve: + @staticmethod + def _make_nonsingular_np(shape, dtype, order): + A = generate_random_numpy_array(shape, dtype, order) + m, n = shape + k = min(m, n) + for i in range(k): + off = numpy.sum(numpy.abs(A[i, :n])) - numpy.abs(A[i, i]) + A[i, i] = A.dtype.type(off + 1.0) + return A + + @pytest.mark.parametrize("shape", [(1, 1), (2, 2), (3, 3), (5, 5)]) + @pytest.mark.parametrize("rhs_cols", [None, 1, 3]) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_none=True) + ) + def test_lu_solve(self, shape, rhs_cols, order, dtype): + a_np = self._make_nonsingular_np(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + n = shape[0] + if rhs_cols is None: + b_np = generate_random_numpy_array((n,), dtype, order) + else: + b_np = generate_random_numpy_array((n, rhs_cols), dtype, order) + b_dp = dpnp.array(b_np, order=order) + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, trans=0, overwrite_b=False, check_finite=False + ) + + # check A @ x = b + Ax = a_dp @ x + assert dpnp.allclose(Ax, b_dp, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("trans", [0, 1, 2]) + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_trans(self, trans, dtype): + n = 4 + a_np = self._make_nonsingular_np((n, n), dtype, order="F") + a_dp = dpnp.array(a_np, order="F") + b_dp = dpnp.array(generate_random_numpy_array((n, 2), dtype, "F")) + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False + ) + + if trans == 0: + lhs = a_dp @ x + elif trans == 1: + lhs = a_dp.T @ x + else: # trans == 2 + lhs = a_dp.conj().T @ x + + assert dpnp.allclose(lhs, b_dp, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_overwrite_inplace(self, dtype): + a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F") + b_dp = dpnp.array([1, 0], dtype=dtype, order="F") + b_orig = b_dp.copy() + + lu, piv = dpnp.linalg.lu_factor( + a_dp, overwrite_a=False, check_finite=False + ) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, trans=0, overwrite_b=True, check_finite=False + ) + + assert x is b_dp + assert x.data.ptr == b_dp.data.ptr + assert dpnp.allclose(a_dp @ x, b_orig, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def 
test_overwrite_copy_special(self, dtype):
+        a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F")
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+
+        # F-contig but dtype != res_type
+        b1 = dpnp.array([1, 0], dtype=dpnp.int32, order="F")
+        x1 = dpnp.linalg.lu_solve(
+            (lu, piv), b1, overwrite_b=True, check_finite=False
+        )
+        assert x1 is not b1
+        assert x1.data.ptr != b1.data.ptr
+
+        # F-contig, matching dtype, but read-only input
+        b2 = dpnp.array([1, 0], dtype=dtype, order="F")
+        b2.flags["WRITABLE"] = False
+        x2 = dpnp.linalg.lu_solve(
+            (lu, piv), b2, overwrite_b=True, check_finite=False
+        )
+        assert x2 is not b2
+        assert x2.data.ptr != b2.data.ptr
+
+        for x in (x1, x2):
+            assert dpnp.allclose(
+                a_dp @ x,
+                dpnp.array([1, 0], dtype=x.dtype),
+                rtol=1e-6,
+                atol=1e-6,
+            )
+
+    @pytest.mark.parametrize(
+        "dtype_a", get_all_dtypes(no_bool=True, no_none=True)
+    )
+    @pytest.mark.parametrize(
+        "dtype_b", get_all_dtypes(no_bool=True, no_none=True)
+    )
+    def test_diff_type(self, dtype_a, dtype_b):
+        a_np = self._make_nonsingular_np((3, 3), dtype_a, order="F")
+        a_dp = dpnp.array(a_np, order="F")
+
+        b_np = generate_random_numpy_array((3,), dtype_b, order="F")
+        b_dp = dpnp.array(b_np, order="F")
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)
+        assert dpnp.allclose(
+            a_dp @ x, b_dp.astype(x.dtype, copy=False), rtol=1e-5, atol=1e-5
+        )
+
+    def test_strided_rhs(self):
+        n = 7
+        a_np = self._make_nonsingular_np(
+            (n, n), dpnp.default_float_type(), order="F"
+        )
+        a_dp = dpnp.array(a_np, order="F")
+
+        rhs_full = (
+            dpnp.arange(n * n, dtype=dpnp.default_float_type()).reshape(
+                n, n, order="F"
+            )
+            + 1.0
+        )
+        b_dp = rhs_full[:, ::2][:, :3]
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve(
+            (lu, piv), b_dp, overwrite_b=False, check_finite=False
+        )
+
+        assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5)
+
+    @pytest.mark.parametrize(
+        "b_shape",
+        [
+            (4,),
+            (4, 1),
+            (4, 3),
+            # batched shapes are disabled until broadcasting is supported
+            # (see the TODO in dpnp_lu_solve):
+            # (1, 4, 3),
+            # (2, 4, 3),
+            # (1, 1, 4, 3)
+        ],
+    )
+    def test_broadcast_rhs(self, b_shape):
+        dtype = dpnp.default_float_type()
+
+        a_np = self._make_nonsingular_np((4, 4), dtype, order="F")
+        a_dp = dpnp.array(a_np, order="F")
+
+        b_np = generate_random_numpy_array(b_shape, dtype, order="F")
+        b_dp = dpnp.array(b_np, order="F")
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve(
+            (lu, piv), b_dp, overwrite_b=False, check_finite=False
+        )
+
+        assert x.shape == b_dp.shape
+
+        assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5)
+
+    @pytest.mark.parametrize("shape", [(0, 0), (0, 5), (5, 5)])
+    @pytest.mark.parametrize("rhs_cols", [None, 0, 3])
+    def test_empty_shapes(self, shape, rhs_cols):
+        a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F")
+        if min(shape) > 0:
+            for i in range(min(shape)):
+                a_dp[i, i] = a_dp.dtype.type(1.0)
+
+        n = shape[0]
+        if rhs_cols is None:
+            b_shape = (n,)
+        else:
+            b_shape = (n, rhs_cols)
+        b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F")
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)
+
+        assert x.shape == b_shape
+
+    @pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan])
+    def test_check_finite_raises(self, bad):
+        a_dp = dpnp.array([[1.0, 0.0], [0.0, 1.0]], order="F")
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+
+        b_bad = dpnp.array([1.0, bad], order="F")
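+        # with check_finite=True, lu_solve must validate `b`
+        # (synchronizing) and raise on the non-finite entry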
+        assert_raises(
+            ValueError,
+            dpnp.linalg.lu_solve,
+            (lu, piv),
+            b_bad,
+            check_finite=True,
+        )
+
+
 class TestMatrixPower:
     @pytest.mark.parametrize("dtype", get_all_dtypes())
     @pytest.mark.parametrize(
diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py
index 413925f8d55..f4299325762 100644
--- a/dpnp/tests/test_sycl_queue.py
+++ b/dpnp/tests/test_sycl_queue.py
@@ -9,6 +9,7 @@
 from numpy.testing import assert_array_equal, assert_raises
 
 import dpnp
+import dpnp.linalg
 from dpnp.dpnp_array import dpnp_array
 from dpnp.dpnp_utils import get_usm_allocations
 
@@ -1610,6 +1611,20 @@ def test_lu_factor(self, data, device):
             param_queue = param.sycl_queue
             assert_sycl_queue_equal(param_queue, a.sycl_queue)
 
+    @pytest.mark.parametrize(
+        "b_data",
+        [[1.0, 2.0], numpy.empty((2, 0))],
+    )
+    def test_lu_solve(self, b_data, device):
+        a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device)
+        lu, piv = dpnp.linalg.lu_factor(a)
+        b = dpnp.array(b_data, device=device)
+
+        result = dpnp.linalg.lu_solve((lu, piv), b)
+
+        assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue)
+        assert_sycl_queue_equal(result.sycl_queue, b.sycl_queue)
+
     @pytest.mark.parametrize("n", [-1, 0, 1, 2, 3])
     def test_matrix_power(self, n, device):
         x = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device)
diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py
index a945599fe3a..c17526649ab 100644
--- a/dpnp/tests/test_usm_type.py
+++ b/dpnp/tests/test_usm_type.py
@@ -1487,6 +1487,24 @@ def test_lu_factor(self, data, usm_type):
         for param in result:
             assert param.usm_type == a.usm_type
 
+    @pytest.mark.parametrize("usm_type_rhs", list_of_usm_types)
+    @pytest.mark.parametrize(
+        "b_data",
+        [[1.0, 2.0], numpy.empty((2, 0))],
+    )
+    def test_lu_solve(self, b_data, usm_type, usm_type_rhs):
+        a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], usm_type=usm_type)
+        lu, piv = dpnp.linalg.lu_factor(a)
+        b = dpnp.array(b_data, usm_type=usm_type_rhs)
+
+        result = dpnp.linalg.lu_solve((lu, piv), b)
+
+        assert lu.usm_type == usm_type
+        assert b.usm_type == usm_type_rhs
+        assert result.usm_type == du.get_coerced_usm_type(
+            [usm_type, usm_type_rhs]
+        )
+
     @pytest.mark.parametrize("n", [-1, 0, 1, 2, 3])
     def test_matrix_power(self, n, usm_type):
         a = dpnp.array([[1, 2], [3, 5]], usm_type=usm_type)
diff --git a/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/__init__.py b/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py b/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py
new file mode 100644
index 00000000000..2e0da004413
--- /dev/null
+++ b/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py
@@ -0,0 +1,198 @@
+from __future__ import annotations
+
+import unittest
+import warnings
+
+import numpy
+import pytest
+
+import dpnp as cupy
+from dpnp.tests.helper import is_scipy_available
+from dpnp.tests.third_party.cupy import testing
+
+if is_scipy_available():
+    import scipy.linalg
+
+
+# TODO: After the feature is released
+# requires_scipy_linalg_backend = testing.with_requires('scipy>=1.x.x')
+requires_scipy_linalg_backend = unittest.skip(
+    "scipy.linalg backend feature has not been released"
+)
+
+
+@testing.parameterize(
+    *testing.product(
+        {
+            "shape": [
+                (1, 1),
+                (2, 2),
+                (3, 3),
+                (5, 5),
+                (1, 5),
+                (5, 1),
+                (2, 5),
+                (5, 2),
+            ],
+        }
+    )
+)
+@testing.fix_random()
+@testing.with_requires("scipy")
+class TestLUFactor(unittest.TestCase):
+
+    
@testing.for_dtypes("fdFD") + def test_lu_factor(self, dtype): + if self.shape[0] != self.shape[1]: + self.skipTest( + "skip non-square tests since scipy.lu_factor requires square" + ) + a_cpu = testing.shaped_random(self.shape, numpy, dtype=dtype) + a_gpu = cupy.asarray(a_cpu) + result_cpu = scipy.linalg.lu_factor(a_cpu) + # Originally used cupyx.scipy.linalg.lu_factor + result_gpu = cupy.linalg.lu_factor(a_gpu) + assert len(result_cpu) == len(result_gpu) + assert result_cpu[0].dtype == result_gpu[0].dtype + # DPNP returns pivot indices as int64, while SciPy returns int32. + # Check for the expected dtypes explicitly. + # assert result_cpu[1].dtype == result_gpu[1].dtype + assert result_cpu[1].dtype == cupy.int32 + assert result_gpu[1].dtype == cupy.int64 + testing.assert_allclose(result_cpu[0], result_gpu[0], atol=1e-5) + testing.assert_array_equal(result_cpu[1], result_gpu[1]) + + def check_lu_factor_reconstruction(self, A): + m, n = self.shape + lu, piv = cupy.linalg.lu_factor(A) + # extract ``L`` and ``U`` from ``lu`` + L = cupy.tril(lu, k=-1) + cupy.fill_diagonal(L, 1.0) + L = L[:, :m] + U = cupy.triu(lu) + U = U[:n, :] + # check output shapes + assert lu.shape == (m, n) + assert L.shape == (m, min(m, n)) + assert U.shape == (min(m, n), n) + assert piv.shape == (min(m, n),) + # apply pivot (on CPU since slaswp is not available in cupy) + piv = cupy.asnumpy(piv) + rows = list(range(m)) + for i, row in enumerate(piv): + if i != row: + rows[i], rows[row] = rows[row], rows[i] + rows = cupy.asarray(rows) + PA = A[rows] + # check that reconstruction is close to original + LU = L.dot(U) + testing.assert_allclose(LU, PA, atol=1e-5) + + @testing.for_dtypes("fdFD") + def test_lu_factor_reconstruction(self, dtype): + A = testing.shaped_random(self.shape, cupy, dtype=dtype) + self.check_lu_factor_reconstruction(A) + + @testing.for_dtypes("fdFD") + def test_lu_factor_reconstruction_singular(self, dtype): + if self.shape[0] != self.shape[1]: + self.skipTest( + "skip non-square tests since scipy.lu_factor requires square" + ) + A = testing.shaped_random(self.shape, cupy, dtype=dtype) + A -= A.mean(axis=0, keepdims=True) + A -= A.mean(axis=1, keepdims=True) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + self.check_lu_factor_reconstruction(A) + + +@testing.parameterize( + *testing.product( + { + "shape": [ + (1, 1), + (2, 2), + (3, 3), + (5, 5), + (1, 5), + (5, 1), + (2, 5), + (5, 2), + ], + "permute_l": [False, True], + } + ) +) +@testing.fix_random() +@testing.with_requires("scipy") +@pytest.mark.skip("lu() is not supported yet") +class TestLU(unittest.TestCase): + + @testing.for_dtypes("fdFD") + def test_lu(self, dtype): + a_cpu = testing.shaped_random(self.shape, numpy, dtype=dtype) + a_gpu = cupy.asarray(a_cpu) + result_cpu = scipy.linalg.lu(a_cpu, permute_l=self.permute_l) + result_gpu = cupy.linalg.lu(a_gpu, permute_l=self.permute_l) + assert len(result_cpu) == len(result_gpu) + if not self.permute_l: + # check permutation matrix + result_cpu = list(result_cpu) + result_gpu = list(result_gpu) + P_cpu = result_cpu.pop(0) + P_gpu = result_gpu.pop(0) + cupy.testing.assert_array_equal(P_gpu, P_cpu) + cupy.testing.assert_allclose(result_gpu[0], result_cpu[0], atol=1e-5) + cupy.testing.assert_allclose(result_gpu[1], result_cpu[1], atol=1e-5) + + @testing.for_dtypes("fdFD") + def test_lu_reconstruction(self, dtype): + m, n = self.shape + A = testing.shaped_random(self.shape, cupy, dtype=dtype) + if self.permute_l: + PL, U = cupy.linalg.lu(A, 
permute_l=self.permute_l) + PLU = PL @ U + else: + P, L, U = cupy.linalg.lu(A, permute_l=self.permute_l) + PLU = P @ L @ U + # check that reconstruction is close to original + cupy.testing.assert_allclose(PLU, A, atol=1e-5) + + +@testing.parameterize( + *testing.product( + { + "trans": [0, 1, 2], + "shapes": [((4, 4), (4,)), ((5, 5), (5, 2))], + } + ) +) +@testing.fix_random() +@testing.with_requires("scipy") +class TestLUSolve(unittest.TestCase): + + @testing.for_dtypes("fdFD") + @testing.numpy_cupy_allclose(atol=1e-5, scipy_name="scp") + def test_lu_solve(self, xp, scp, dtype): + a_shape, b_shape = self.shapes + A = testing.shaped_random(a_shape, xp, dtype=dtype) + b = testing.shaped_random(b_shape, xp, dtype=dtype) + lu = scp.linalg.lu_factor(A) + return scp.linalg.lu_solve(lu, b, trans=self.trans) + + @requires_scipy_linalg_backend + @testing.for_dtypes("fdFD") + @testing.numpy_cupy_allclose(atol=1e-5) + def test_lu_solve_backend(self, xp, dtype): + a_shape, b_shape = self.shapes + A = testing.shaped_random(a_shape, xp, dtype=dtype) + b = testing.shaped_random(b_shape, xp, dtype=dtype) + if xp is numpy: + lu = scipy.linalg.lu_factor(A) + backend = "scipy" + else: + lu = cupy.linalg.lu_factor(A) + backend = cupy.linalg + with scipy.linalg.set_backend(backend): + out = scipy.linalg.lu_solve(lu, b, trans=self.trans) + return out
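
A minimal end-to-end sketch of the API this diff adds (values are
illustrative; assumes a dpnp build with a working SYCL device):

    import dpnp as np

    # factor A once, then reuse the factorization for several right-hand sides
    A = np.array([[3.0, 1.0], [1.0, 2.0]])
    lu, piv = np.linalg.lu_factor(A)

    b1 = np.array([9.0, 8.0])
    b2 = np.array([1.0, 0.0])
    x1 = np.linalg.lu_solve((lu, piv), b1)           # solves A @ x1 == b1
    x2 = np.linalg.lu_solve((lu, piv), b2, trans=1)  # solves A.T @ x2 == b2

    assert np.allclose(A @ x1, b1)
    assert np.allclose(A.T @ x2, b2)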