Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions src/extensions/lapack/lapack-templates.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
lda
ipiv
info)
(values a-tensor ipiv-tensor))))
(values (magicl::from-storage a (list m n) :layout :column-major) ipiv-tensor))))

(defun generate-lapack-lu-solve-for-type (class type lu-solve-function)
(declare (ignore type))
Expand Down Expand Up @@ -530,10 +530,10 @@
(setf work (make-array (max 1 lwork) :element-type ',type))
;; run it again with optimal workspace size
(,qr-function rows cols a lda tau work lwork info)
(values a-tensor
(from-array tau (list (min rows cols))
:type ',type
:input-layout :column-major)))))
(values (magicl::from-storage a (list rows cols)
:layout :column-major)
(magicl::from-storage tau (list (min rows cols))
:layout :column-major)))))

(defmethod lapack-ql ((m ,class))
(let* ((rows (nrows m))
Expand All @@ -553,10 +553,10 @@
(setf work (make-array (max 1 lwork) :element-type ',type))
;; run it again with optimal workspace size
(,ql-function rows cols a lda tau work lwork info)
(values a-tensor
(from-array tau (list (min rows cols))
:type ',type
:input-layout :column-major)))))
(values (magicl::from-storage a (list rows cols)
:layout :column-major)
(magicl::from-storage tau (list (min rows cols))
:layout :column-major)))))

(defmethod lapack-rq ((m ,class))
(let* ((rows (nrows m))
Expand All @@ -576,10 +576,10 @@
(setf work (make-array (max 1 lwork) :element-type ',type))
;; run it again with optimal workspace size
(,rq-function rows cols a lda tau work lwork info)
(values a-tensor
(from-array tau (list (min rows cols))
:type ',type
:input-layout :column-major)))))
(values (magicl::from-storage a (list rows cols)
:layout :column-major)
(magicl::from-storage tau (list (min rows cols))
:layout :column-major)))))

(defmethod lapack-lq ((m ,class))
(let* ((rows (nrows m))
Expand All @@ -599,10 +599,10 @@
(setf work (make-array (max 1 lwork) :element-type ',type))
;; run it again with optimal workspace size
(,lq-function rows cols a lda tau work lwork info)
(values a-tensor
(from-array tau (list (min rows cols))
:type ',type
:input-layout :column-major)))))
(values (magicl::from-storage a (list rows cols)
:layout :column-major)
(magicl::from-storage tau (list (min rows cols))
:layout :column-major)))))

(defmethod lapack-qr-q ((m ,class) tau)
(let ((m (nrows m))
Expand Down
38 changes: 15 additions & 23 deletions src/high-level/matrix.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -418,29 +418,6 @@ In the world of BLAS/LAPACK, this is known as GEMM.
(empty (list (* ma mb) (* na nb))
:type '(complex double-float))))))))

(define-extensible-function (transpose! transpose!-lisp) (matrix &key fast)
(:documentation "Transpose MATRIX, replacing the elements of MATRIX, optionally performing a faster change of layout if FAST is specified")
(:method ((matrix matrix) &key fast)
"Transpose a matrix by copying values.
If FAST is t then just change layout. Fast can cause problems when you want to multiply specifying transpose."
(if fast
(progn (rotatef (matrix-ncols matrix) (matrix-nrows matrix))
(setf (matrix-layout matrix) (ecase (matrix-layout matrix)
(:row-major :column-major)
(:column-major :row-major))))
(let ((index-function
(ecase (matrix-layout matrix)
(:row-major #'matrix-row-major-index)
(:column-major #'matrix-column-major-index)))
(shape (shape matrix)))
(loop :for row :below (matrix-nrows matrix)
:do (loop :for col :from row :below (matrix-ncols matrix)
:do (rotatef
(aref (storage matrix) (apply index-function row col shape))
(aref (storage matrix) (apply index-function col row shape)))))
(rotatef (matrix-ncols matrix) (matrix-nrows matrix))))
matrix))

(define-extensible-function (transpose transpose-lisp) (matrix)
(:documentation "Create a new matrix containing the transpose of MATRIX")
(:method ((matrix matrix))
Expand All @@ -461,6 +438,21 @@ If fast is t then just change layout. Fast can cause problems when you want to m
(setf (aref (storage new-matrix) index2) (aref (storage matrix) index1))))))
new-matrix)))

(define-extensible-function (transpose! transpose!-lisp) (matrix &key fast)
(:documentation "Transpose MATRIX, replacing the elements of MATRIX, optionally performing a faster change of layout if FAST is specified")
(:method ((matrix matrix) &key fast)
"Transpose a matrix by copying values.
If FAST is t then just change layout. Fast can cause problems when you want to multiply specifying transpose."
(if fast
(progn (rotatef (matrix-ncols matrix) (matrix-nrows matrix))
(setf (matrix-layout matrix) (ecase (matrix-layout matrix)
(:row-major :column-major)
(:column-major :row-major))))
;; TODO: make it really inplace
(let ((transposed (transpose matrix)))
(replace (storage matrix) (storage transposed))))
matrix))

;; TODO: allow setf on matrix diag
(define-extensible-function (diag diag-lisp) (matrix)
(:documentation "Get a list of the diagonal elements of MATRIX")
Expand Down
41 changes: 41 additions & 0 deletions tests/high-level-tests.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,44 @@
(b (magicl:from-list '((0d0 0d0) (1d0 1d0) (5d0 2d0) (6d0 3d0)) '(4 2))))
(let ((x (magicl:least-squares A b)))
(is (magicl:= x (magicl:from-list '((2.2d0 1d0) (-2.5d0 -1d0)) '(2 2)))))))

(deftest data-layout-invariant ()
"Check that QR/QL/LU/etc. is invariant to row- or column-major order"
(loop repeat 50
for row-major = (magicl:rand '(100 50) :layout :row-major)
for column-major = (magicl:@ row-major (magicl:eye '(50 50) :layout :column-major))
for rms = (magicl:rand '(50 50) :layout :row-major)
for cms = (magicl:@ rms (magicl:eye '(50 50) :layout :column-major))
do
;; QR
(is (equalp
(multiple-value-list
(magicl:qr row-major))
(multiple-value-list
(magicl:qr column-major))))
;; QL
(is (equalp
(multiple-value-list
(magicl:ql row-major))
(multiple-value-list
(magicl:ql column-major))))
;; LU
(is (equalp
(multiple-value-list
(magicl:lu row-major))
(multiple-value-list
(magicl:lu column-major))))
;; Solver
(flet ((solve (m)
(multiple-value-bind (lu ipiv)
(magicl:lu m)
(magicl:lu-solve lu ipiv (magicl:ones '(100 1))))))
(is (equalp (solve rms) (solve cms))))
;; Inversion
(is (equalp (magicl:inv rms) (magicl:inv cms)))
;; SVD
(is (equalp
(multiple-value-list
(magicl:svd row-major))
(multiple-value-list
(magicl:svd column-major))))))