@@ -2477,6 +2477,121 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24772477 return (a_h , ipiv_h )
24782478
24792479
2480+ def dpnp_lu_solve (lu , piv , b , trans = 0 , overwrite_b = False , check_finite = True ):
2481+ """
2482+ dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)
2483+
2484+ Solve an equation system (SciPy-compatible behavior).
2485+
2486+ This function mimics the behavior of `scipy.linalg.lu_solve` including
2487+ support for `trans`, `overwrite_b`, `check_finite`,
2488+ and 0-based pivot indexing.
2489+
2490+ """
2491+
2492+ res_usm_type , exec_q = get_usm_allocations ([lu , piv , b ])
2493+
2494+ res_type = _common_type (lu , b )
2495+
2496+ # TODO: add broadcasting
2497+ if lu .shape [0 ] != b .shape [0 ]:
2498+ raise ValueError (
2499+ f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
2500+ )
2501+
2502+ if b .size == 0 :
2503+ return dpnp .empty_like (b , dtype = res_type , usm_type = res_usm_type )
2504+
2505+ if lu .ndim > 2 :
2506+ raise NotImplementedError ("Batched matrices are not supported" )
2507+
2508+ if check_finite :
2509+ if not dpnp .isfinite (lu ).all ():
2510+ raise ValueError (
2511+ "LU factorization array must not contain infs or NaNs.\n "
2512+ "Note that when a singular matrix is given, unlike "
2513+ "dpnp.linalg.lu_factor returns an array containing NaN."
2514+ )
2515+ if not dpnp .isfinite (b ).all ():
2516+ raise ValueError (
2517+ "Right-hand side array must not contain infs or NaNs"
2518+ )
2519+
2520+ lu_usm_arr = dpnp .get_usm_ndarray (lu )
2521+ b_usm_arr = dpnp .get_usm_ndarray (b )
2522+
2523+ # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
2524+ # convert to 1-based for oneMKL getrs
2525+ piv_h = piv + 1
2526+
2527+ _manager = dpu .SequentialOrderManager [exec_q ]
2528+ dep_evs = _manager .submitted_events
2529+
2530+ # oneMKL LAPACK getrs overwrites `lu`.
2531+ lu_h = dpnp .empty_like (lu , order = "F" , dtype = res_type , usm_type = res_usm_type )
2532+
2533+ # use DPCTL tensor function to fill the сopy of the input array
2534+ # from the input array
2535+ ht_ev , lu_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2536+ src = lu_usm_arr ,
2537+ dst = lu_h .get_array (),
2538+ sycl_queue = lu .sycl_queue ,
2539+ depends = dep_evs ,
2540+ )
2541+ _manager .add_event_pair (ht_ev , lu_copy_ev )
2542+
2543+ # SciPy-compatible behavior
2544+ # Copy is required if:
2545+ # - overwrite_b is False (always copy),
2546+ # - dtype mismatch,
2547+ # - not F-contiguous,
2548+ # - not writeable
2549+ if not overwrite_b or _is_copy_required (b , res_type ):
2550+ b_h = dpnp .empty_like (
2551+ b , order = "F" , dtype = res_type , usm_type = res_usm_type
2552+ )
2553+ ht_ev , b_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2554+ src = b_usm_arr ,
2555+ dst = b_h .get_array (),
2556+ sycl_queue = b .sycl_queue ,
2557+ depends = dep_evs ,
2558+ )
2559+ _manager .add_event_pair (ht_ev , b_copy_ev )
2560+ dep_evs = [lu_copy_ev , b_copy_ev ]
2561+ else :
2562+ # input is suitable for in-place modification
2563+ b_h = b
2564+ dep_evs = [lu_copy_ev ]
2565+
2566+ if not isinstance (trans , int ):
2567+ raise TypeError ("`trans` must be an integer" )
2568+
2569+ # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
2570+ if trans == 0 :
2571+ trans_mkl = li .Transpose .N
2572+ elif trans == 1 :
2573+ trans_mkl = li .Transpose .T
2574+ elif trans == 2 :
2575+ trans_mkl = li .Transpose .C
2576+ else :
2577+ raise ValueError ("`trans` must be 0 (N), 1 (T), or 2 (C)" )
2578+
2579+ # Call the LAPACK extension function _getrs
2580+ # to solve the system of linear equations with an LU-factored
2581+ # coefficient square matrix, with multiple right-hand sides.
2582+ ht_ev , getrs_ev = li ._getrs (
2583+ exec_q ,
2584+ lu_h .get_array (),
2585+ piv_h .get_array (),
2586+ b_h .get_array (),
2587+ trans_mkl ,
2588+ depends = dep_evs ,
2589+ )
2590+ _manager .add_event_pair (ht_ev , getrs_ev )
2591+
2592+ return b_h
2593+
2594+
24802595def dpnp_matrix_power (a , n ):
24812596 """
24822597 dpnp_matrix_power(a, n)
0 commit comments