Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4223c72
Pass trans_code to getrs in dpnp_solve()
vlad-perevezentsev Aug 18, 2025
80ce50c
Remove TODO
vlad-perevezentsev Sep 4, 2025
af0ab7d
Implement of dpnp.linalg.lu_solve for 2D inputs
vlad-perevezentsev Sep 4, 2025
17b11ae
Add dpnp.linalg.lu_solve to generated docs
vlad-perevezentsev Sep 4, 2025
b10a8d6
Add TestLuSolve to test_linalg.py
vlad-perevezentsev Sep 4, 2025
2021f77
Add sycl_queue and usm_type tests
vlad-perevezentsev Sep 4, 2025
be2725a
Update doc/comment lines
vlad-perevezentsev Sep 16, 2025
1e09cb7
Update dependency logic
vlad-perevezentsev Sep 18, 2025
9345b7b
Add trans code handling
vlad-perevezentsev Sep 18, 2025
687006f
Fix docs for lu:must be square
vlad-perevezentsev Sep 18, 2025
b1aed58
Merge master into impl_lu_solve_2D
vlad-perevezentsev Sep 18, 2025
9aaff82
Update changelog
vlad-perevezentsev Sep 18, 2025
23ad15d
Apply docs remarks
vlad-perevezentsev Sep 19, 2025
82de136
Apply remarks
vlad-perevezentsev Sep 19, 2025
e586075
Add assert on USM data pointer to tests
vlad-perevezentsev Sep 19, 2025
7d1fd0b
Update data inputs for test_usm_type
vlad-perevezentsev Sep 19, 2025
87074fa
Add See Also to lu_factor
vlad-perevezentsev Sep 22, 2025
d81454e
Enable cupyx tests
vlad-perevezentsev Sep 22, 2025
4c9afa9
Merge master into impl_lu_solve_2D
vlad-perevezentsev Sep 22, 2025
78a4c78
Adjust tolerance for test_lu_solve
vlad-perevezentsev Sep 22, 2025
27649a7
Merge master into impl_lu_solve_2D
vlad-perevezentsev Sep 22, 2025
d0fbd49
Apply remark
vlad-perevezentsev Sep 22, 2025
52eac3d
Adjust tolerance for integer dtypes
vlad-perevezentsev Sep 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534)
* Added implementation of `dpnp.linalg.lu_factor` (SciPy-compatible) [#2557](https://github.com/IntelPython/dpnp/pull/2557), [#2565](https://github.com/IntelPython/dpnp/pull/2565)
* Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550)
* Added implementation of `dpnp.linalg.lu_solve` for 2D inputs (SciPy-compatible) [#2575](https://github.com/IntelPython/dpnp/pull/2575)

### Changed

Expand Down
1 change: 1 addition & 0 deletions doc/reference/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ Solving linear equations
dpnp.linalg.solve
dpnp.linalg.tensorsolve
dpnp.linalg.lstsq
dpnp.linalg.lu_solve
dpnp.linalg.inv
dpnp.linalg.pinv
dpnp.linalg.tensorinv
Expand Down
84 changes: 83 additions & 1 deletion dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
dpnp_inv,
dpnp_lstsq,
dpnp_lu_factor,
dpnp_lu_solve,
dpnp_matrix_power,
dpnp_matrix_rank,
dpnp_multi_dot,
Expand All @@ -81,6 +82,7 @@
"inv",
"lstsq",
"lu_factor",
"lu_solve",
"matmul",
"matrix_norm",
"matrix_power",
Expand Down Expand Up @@ -905,7 +907,7 @@ def lstsq(a, b, rcond=None):

def lu_factor(a, overwrite_a=False, check_finite=True):
"""
Compute the pivoted LU decomposition of a matrix.
Compute the pivoted LU decomposition of `a` matrix.

The decomposition is::

Expand Down Expand Up @@ -947,6 +949,11 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
This function synchronizes in order to validate array elements
when ``check_finite=True``.

See Also
--------
:obj:`dpnp.linalg.lu_solve` : Solve an equation system using
the LU factorization of `a` matrix.

Examples
--------
>>> import dpnp as np
Expand All @@ -966,6 +973,81 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)


def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    Solve a linear system, :math:`a x = b`, given the LU factorization of `a`.

    For full documentation refer to :obj:`scipy.linalg.lu_solve`.

    Parameters
    ----------
    lu_and_piv : {tuple of dpnp.ndarrays or usm_ndarrays}
        Pair ``(lu, piv)``: LU factorization of matrix `a` (M, M) together
        with pivot indices, as returned by :obj:`dpnp.linalg.lu_factor`.
    b : {(M,), (M, K)} {dpnp.ndarray, usm_ndarray}
        Right-hand side. Its first dimension must match the first
        dimension of `lu`.
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =================
        trans  system
        =====  =================
        0      :math:`a x = b`
        1      :math:`a^T x = b`
        2      :math:`a^H x = b`
        =====  =================

        Default: ``0``.
    overwrite_b : {None, bool}, optional
        Whether to overwrite data in `b` (may increase performance).

        Default: ``False``.
    check_finite : {None, bool}, optional
        Whether to check that the input arrays (`lu` and `b`) contain only
        finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

        Default: ``True``.

    Returns
    -------
    x : {(M,), (M, K)} dpnp.ndarray
        Solution to the system.

    Warning
    -------
    This function synchronizes in order to validate array elements
    when ``check_finite=True``.

    See Also
    --------
    :obj:`dpnp.linalg.lu_factor` : LU factorize a matrix.

    Examples
    --------
    >>> import dpnp as np
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = np.linalg.lu_factor(A)
    >>> x = np.linalg.lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    array(True)

    """

    # The factorization arrives as a single (lu, piv) pair, SciPy-style.
    lu, piv = lu_and_piv
    dpnp.check_supported_arrays_type(lu, piv, b)
    assert_stacked_2d(lu)

    return dpnp_lu_solve(
        lu,
        piv,
        b,
        trans=trans,
        overwrite_b=overwrite_b,
        check_finite=check_finite,
    )


def matmul(x1, x2, /):
"""
Computes the matrix product.
Expand Down
128 changes: 128 additions & 0 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2477,6 +2477,134 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
return (a_h, ipiv_h)


def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)

    Solve an equation system (SciPy-compatible behavior).

    This function mimics the behavior of `scipy.linalg.lu_solve` including
    support for `trans`, `overwrite_b`, `check_finite`,
    and 0-based pivot indexing.

    Parameters
    ----------
    lu : {dpnp.ndarray, usm_ndarray}
        LU factorization of the coefficient matrix (2D only for now).
    piv : {dpnp.ndarray, usm_ndarray}
        0-based pivot indices; the input array is never modified.
    b : {dpnp.ndarray, usm_ndarray}
        Right-hand side whose first dimension matches ``lu.shape[0]``.
    trans : {0, 1, 2}, optional
        System to solve: ``0`` - no transpose, ``1`` - transpose,
        ``2`` - conjugate transpose.
    overwrite_b : bool, optional
        Allow the solution to be written into `b` when no copy is required.
    check_finite : bool, optional
        Validate that `lu` and `b` contain only finite values.

    Returns
    -------
    out : dpnp.ndarray
        Solution of the system. An empty array shaped like `b` is returned
        when `b` has no elements.

    Raises
    ------
    ValueError
        If shapes of `lu` and `b` are incompatible, if `check_finite` is
        enabled and `lu` or `b` contains infs or NaNs, or if `trans` is not
        one of 0, 1, 2.
    NotImplementedError
        If `lu` has more than two dimensions.
    TypeError
        If `trans` is not an integer.

    """

    res_usm_type, exec_q = get_usm_allocations([lu, piv, b])

    res_type = _common_type(lu, b)

    # TODO: add broadcasting
    if lu.shape[0] != b.shape[0]:
        raise ValueError(
            f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
        )

    if b.size == 0:
        return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)

    if lu.ndim > 2:
        raise NotImplementedError("Batched matrices are not supported")

    if check_finite:
        if not dpnp.isfinite(lu).all():
            # NOTE(review): the second sentence of this message reads as
            # truncated ("unlike ... returns") - confirm the intended wording.
            raise ValueError(
                "LU factorization array must not contain infs or NaNs.\n"
                "Note that when a singular matrix is given, unlike "
                "dpnp.linalg.lu_factor returns an array containing NaN."
            )
        if not dpnp.isfinite(b).all():
            raise ValueError(
                "Right-hand side array must not contain infs or NaNs"
            )

    # Validate `trans` before any device work is queued, so that an invalid
    # value does not cost three array copies and a pivot update first.
    if not isinstance(trans, int):
        raise TypeError("`trans` must be an integer")

    # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
    if trans == 0:
        trans_mkl = li.Transpose.N
    elif trans == 1:
        trans_mkl = li.Transpose.T
    elif trans == 2:
        trans_mkl = li.Transpose.C
    else:
        raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")

    lu_usm_arr = dpnp.get_usm_ndarray(lu)
    piv_usm_arr = dpnp.get_usm_ndarray(piv)
    b_usm_arr = dpnp.get_usm_ndarray(b)

    _manager = dpu.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    # getrs is handed an F-ordered copy of `lu` cast to the result dtype
    lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type)

    # use DPCTL tensor function to fill the copy of the input array
    # from the input array
    ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=lu_usm_arr,
        dst=lu_h.get_array(),
        sycl_queue=lu.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, lu_copy_ev)

    # `piv` is copied because the pivots are shifted in place below and
    # the caller's array must stay untouched
    piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type)

    # use DPCTL tensor function to fill the copy of the pivot array
    # from the pivot array
    ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=piv_usm_arr,
        dst=piv_h.get_array(),
        sycl_queue=piv.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, piv_copy_ev)

    # SciPy-compatible behavior
    # Copy is required if:
    # - overwrite_b is False (always copy),
    # - dtype mismatch,
    # - not F-contiguous,
    # - not writeable
    if not overwrite_b or _is_copy_required(b, res_type):
        b_h = dpnp.empty_like(
            b, order="F", dtype=res_type, usm_type=res_usm_type
        )
        ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=b_usm_arr,
            dst=b_h.get_array(),
            sycl_queue=b.sycl_queue,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_ev, b_copy_ev)
    else:
        # input is suitable for in-place modification
        b_h = b

    # The public interface keeps SciPy's 0-based pivot indices, while the
    # oneMKL LAPACK getrs routine expects 1-based (Fortran-style) indices,
    # so the private copy of the pivots is shifted by one.
    piv_h += 1

    # getrs must wait for the copies above and for the in-place pivot shift;
    # the copy events alone do not cover the shift, so depend on everything
    # submitted to the queue's order manager so far.
    dep_evs = _manager.submitted_events

    # Call the LAPACK extension function _getrs
    # to solve the system of linear equations with an LU-factored
    # coefficient square matrix, with multiple right-hand sides.
    ht_ev, getrs_ev = li._getrs(
        exec_q,
        lu_h.get_array(),
        piv_h.get_array(),
        b_h.get_array(),
        trans_mkl,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, getrs_ev)

    return b_h


def dpnp_matrix_power(a, n):
"""
dpnp_matrix_power(a, n)
Expand Down
9 changes: 9 additions & 0 deletions dpnp/tests/helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib.util
from sys import platform

import dpctl
Expand Down Expand Up @@ -488,6 +489,14 @@ def is_ptl(device=None):
return _get_dev_mask(device) in (0xB000, 0xFD00)


def is_scipy_available():
    """
    Report whether the SciPy package can be located.

    Returns True when an import spec for ``scipy`` is found,
    False otherwise. The package itself is not imported.
    """
    scipy_spec = importlib.util.find_spec("scipy")
    return scipy_spec is not None


def is_tgllp_iris_xe(device=None):
"""
Return True if a test is running on Tiger Lake-LP with Iris Xe GPU device,
Expand Down
Loading
Loading