@@ -192,30 +192,43 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
192192
193193 commit! (cmdbuf)
194194
195+ wait_completed (cmdbuf)
196+
195197 return B
196198end
197199
198200
# Solve `A * X = B` for `X`, where `A` is an LU factorization whose factors and pivots
# live in Metal arrays. Returns a newly allocated array holding the solution; the
# caller's `B` is not written to, because the in-place solver runs on a copy.
function LinearAlgebra.:(\)(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # `copy` is sufficient for an array of plain floats and avoids depending on
    # `deepcopy`'s generic field recursion for a device-backed array.
    C = copy(B)
    LinearAlgebra.ldiv!(A, C)
    return C
end
206+
207+
# In-place solve of `A * X = B` using an MPS LU solve kernel; `B` is overwritten with
# the solution and returned.
#
# `A.factors` and `B` are transposed into scratch arrays before being wrapped as
# `MPSMatrix`, and the kernel output `X` is transposed back into `B` afterwards.
# NOTE(review): presumably this converts between Julia's column-major layout and the
# row-major layout MPS expects -- confirm against the `MPSMatrix` wrapper.
function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
    M, N = size(B, 1), size(B, 2)
    dev = current_device()
    queue = global_queue(dev)

    # Scratch arrays: transposed factors, transposed right-hand side, and the output.
    At = similar(A.factors)
    Bt = similar(B, (N, M))
    X = similar(B, (N, M))

    # Pivot indices with one subtracted from each entry, reshaped to a 1-by-M matrix.
    P = reshape((A.ipiv .- UInt32(1)), (1, M))

    transpose!(At, A.factors)
    transpose!(Bt, B)

    mps_a = MPSMatrix(At)
    mps_b = MPSMatrix(Bt)
    mps_p = MPSMatrix(P)
    mps_x = MPSMatrix(X)

    # Keep the command buffer so completion can be awaited before `X` is read,
    # matching the triangular `ldiv!` methods in this file.
    buf = MTLCommandBuffer(queue) do cmdbuf
        kernel = MPSMatrixSolveLU(dev, false, M, N)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
    end

    wait_completed(buf)

    transpose!(B, X)
    return B
end
221234
@@ -225,20 +238,24 @@ function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM
225238 dev = current_device ()
226239 queue = global_queue (dev)
227240
228- Ad = MtlMatrix (A; storage= Private)
229- Bt = reshape (B, (N,M))
230- X = similar (B)
241+ Ad = MtlMatrix (A' )
242+ Br = similar (B, (M,M))
243+ X = similar (Br)
244+
245+ transpose! (Br, B)
231246
232247 mps_a = MPSMatrix (Ad)
233- mps_b = MPSMatrix (Bt )
248+ mps_b = MPSMatrix (Br )
234249 mps_x = MPSMatrix (X)
235250
236- MTLCommandBuffer (queue) do cmdbuf
237- kernel = MPSMatrixSolveTriangular (dev, false , false , false , false , M, N , 1.0 )
251+ buf = MTLCommandBuffer (queue) do cmdbuf
252+ kernel = MPSMatrixSolveTriangular (dev, false , true , false , false , N, M , 1.0 )
238253 encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
239254 end
240255
241- Bt .= X
256+ wait_completed (buf)
257+
258+ copy! (B, X)
242259 return B
243260end
244261
@@ -248,20 +265,23 @@ function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVe
248265 dev = current_device ()
249266 queue = global_queue (dev)
250267
251- Ad = MtlMatrix (A; storage = Private )
252- Bt = reshape (B, (N,M ))
253- X = similar (B )
268+ Ad = MtlMatrix (A)
269+ Br = reshape (B, (M,N ))
270+ X = similar (Br )
254271
255272 mps_a = MPSMatrix (Ad)
256- mps_b = MPSMatrix (Bt )
273+ mps_b = MPSMatrix (Br )
257274 mps_x = MPSMatrix (X)
258275
259- MTLCommandBuffer (queue) do cmdbuf
260- kernel = MPSMatrixSolveTriangular (dev, false , false , false , true , M, N, 1.0 )
276+
277+ buf = MTLCommandBuffer (queue) do cmdbuf
278+ kernel = MPSMatrixSolveTriangular (dev, true , false , false , true , M, N, 1.0 )
261279 encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
262280 end
263281
264- Bt .= X
282+ wait_completed (buf)
283+
284+ copy! (Br, X)
265285 return B
266286end
267287
@@ -271,20 +291,23 @@ function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM
271291 dev = current_device ()
272292 queue = global_queue (dev)
273293
274- Ad = MtlMatrix (A; storage = Private )
275- Bt = reshape (B, (N,M ))
276- X = similar (B )
294+ Ad = MtlMatrix (A)
295+ Br = reshape (B, (M,N ))
296+ X = similar (Br )
277297
278298 mps_a = MPSMatrix (Ad)
279- mps_b = MPSMatrix (Bt )
299+ mps_b = MPSMatrix (Br )
280300 mps_x = MPSMatrix (X)
281301
282- MTLCommandBuffer (queue) do cmdbuf
283- kernel = MPSMatrixSolveTriangular (dev, false , true , false , false , M, N, 1.0 )
302+
303+ buf = MTLCommandBuffer (queue) do cmdbuf
304+ kernel = MPSMatrixSolveTriangular (dev, true , true , false , false , M, N, 1.0 )
284305 encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
285306 end
286307
287- Bt .= X
308+ wait_completed (buf)
309+
310+ copy! (Br, X)
288311 return B
289312end
290313
@@ -294,19 +317,22 @@ function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVe
294317 dev = current_device ()
295318 queue = global_queue (dev)
296319
297- A = MtlMatrix (A; storage = Private )
298- Bt = reshape (B, (N,M ))
299- X = similar (B )
320+ Ad = MtlMatrix (A)
321+ Br = reshape (B, (M,N ))
322+ X = similar (Br )
300323
301- mps_a = MPSMatrix (A )
302- mps_b = MPSMatrix (Bt )
324+ mps_a = MPSMatrix (Ad )
325+ mps_b = MPSMatrix (Br )
303326 mps_x = MPSMatrix (X)
304327
305- MTLCommandBuffer (queue) do cmdbuf
306- kernel = MPSMatrixSolveTriangular (dev, false , true , false , true , M, N, 1.0 )
328+
329+ buf = MTLCommandBuffer (queue) do cmdbuf
330+ kernel = MPSMatrixSolveTriangular (dev, true , true , false , true , M, N, 1.0 )
307331 encode! (cmdbuf, kernel, mps_a, mps_b, mps_x)
308332 end
309333
310- Bt .= X
334+ wait_completed (buf)
335+
336+ copy! (Br, X)
311337 return B
312338end
0 commit comments