diff --git a/src/extensions/lapack/lapack-templates.lisp b/src/extensions/lapack/lapack-templates.lisp index 698d828..109192c 100644 --- a/src/extensions/lapack/lapack-templates.lisp +++ b/src/extensions/lapack/lapack-templates.lisp @@ -117,7 +117,7 @@ lda ipiv info) - (values a-tensor ipiv-tensor)))) + (values (magicl::from-storage a (list m n) :layout :column-major) ipiv-tensor)))) (defun generate-lapack-lu-solve-for-type (class type lu-solve-function) (declare (ignore type)) @@ -530,10 +530,10 @@ (setf work (make-array (max 1 lwork) :element-type ',type)) ;; run it again with optimal workspace size (,qr-function rows cols a lda tau work lwork info) - (values a-tensor - (from-array tau (list (min rows cols)) - :type ',type - :input-layout :column-major))))) + (values (magicl::from-storage a (list rows cols) + :layout :column-major) + (magicl::from-storage tau (list (min rows cols)) + :layout :column-major))))) (defmethod lapack-ql ((m ,class)) (let* ((rows (nrows m)) @@ -553,10 +553,10 @@ (setf work (make-array (max 1 lwork) :element-type ',type)) ;; run it again with optimal workspace size (,ql-function rows cols a lda tau work lwork info) - (values a-tensor - (from-array tau (list (min rows cols)) - :type ',type - :input-layout :column-major))))) + (values (magicl::from-storage a (list rows cols) + :layout :column-major) + (magicl::from-storage tau (list (min rows cols)) + :layout :column-major))))) (defmethod lapack-rq ((m ,class)) (let* ((rows (nrows m)) @@ -576,10 +576,10 @@ (setf work (make-array (max 1 lwork) :element-type ',type)) ;; run it again with optimal workspace size (,rq-function rows cols a lda tau work lwork info) - (values a-tensor - (from-array tau (list (min rows cols)) - :type ',type - :input-layout :column-major))))) + (values (magicl::from-storage a (list rows cols) + :layout :column-major) + (magicl::from-storage tau (list (min rows cols)) + :layout :column-major))))) (defmethod lapack-lq ((m ,class)) (let* ((rows (nrows m)) @@ -599,10 +599,10 @@ (setf work (make-array (max 1 lwork) :element-type ',type)) ;; run it again with optimal workspace size (,lq-function rows cols a lda tau work lwork info) - (values a-tensor - (from-array tau (list (min rows cols)) - :type ',type - :input-layout :column-major))))) + (values (magicl::from-storage a (list rows cols) + :layout :column-major) + (magicl::from-storage tau (list (min rows cols)) + :layout :column-major))))) (defmethod lapack-qr-q ((m ,class) tau) (let ((m (nrows m)) diff --git a/src/high-level/matrix.lisp b/src/high-level/matrix.lisp index c1c9692..807040a 100644 --- a/src/high-level/matrix.lisp +++ b/src/high-level/matrix.lisp @@ -418,29 +418,6 @@ In the world of BLAS/LAPACK, this is known as GEMM. (empty (list (* ma mb) (* na nb)) :type '(complex double-float)))))))) -(define-extensible-function (transpose! transpose!-lisp) (matrix &key fast) - (:documentation "Transpose MATRIX, replacing the elements of MATRIX, optionally performing a faster change of layout if FAST is specified") - (:method ((matrix matrix) &key fast) - "Transpose a matrix by copying values. -If FAST is t then just change layout. Fast can cause problems when you want to multiply specifying transpose." - (if fast - (progn (rotatef (matrix-ncols matrix) (matrix-nrows matrix)) - (setf (matrix-layout matrix) (ecase (matrix-layout matrix) - (:row-major :column-major) - (:column-major :row-major)))) - (let ((index-function - (ecase (matrix-layout matrix) - (:row-major #'matrix-row-major-index) - (:column-major #'matrix-column-major-index))) - (shape (shape matrix))) - (loop :for row :below (matrix-nrows matrix) - :do (loop :for col :from row :below (matrix-ncols matrix) - :do (rotatef - (aref (storage matrix) (apply index-function row col shape)) - (aref (storage matrix) (apply index-function col row shape))))) - (rotatef (matrix-ncols matrix) (matrix-nrows matrix)))) - matrix)) - (define-extensible-function (transpose transpose-lisp) (matrix) (:documentation "Create a new matrix containing the transpose of MATRIX") (:method ((matrix matrix)) @@ -461,6 +438,21 @@ If fast is t then just change layout. Fast can cause problems when you want to m (setf (aref (storage new-matrix) index2) (aref (storage matrix) index1)))))) new-matrix))) +(define-extensible-function (transpose! transpose!-lisp) (matrix &key fast) + (:documentation "Transpose MATRIX, replacing the elements of MATRIX, optionally performing a faster change of layout if FAST is specified") + (:method ((matrix matrix) &key fast) + "Transpose a matrix by copying values. +If FAST is t then just change layout. Fast can cause problems when you want to multiply specifying transpose." + (if fast + (progn (rotatef (matrix-ncols matrix) (matrix-nrows matrix)) + (setf (matrix-layout matrix) (ecase (matrix-layout matrix) + (:row-major :column-major) + (:column-major :row-major)))) + ;; TODO: make it really inplace + (let ((transposed (transpose matrix))) + (replace (storage matrix) (storage transposed)))) + matrix)) + ;; TODO: allow setf on matrix diag (define-extensible-function (diag diag-lisp) (matrix) (:documentation "Get a list of the diagonal elements of MATRIX") diff --git a/tests/high-level-tests.lisp b/tests/high-level-tests.lisp index abb29d9..7f628cb 100644 --- a/tests/high-level-tests.lisp +++ b/tests/high-level-tests.lisp @@ -306,3 +306,44 @@ (b (magicl:from-list '((0d0 0d0) (1d0 1d0) (5d0 2d0) (6d0 3d0)) '(4 2)))) (let ((x (magicl:least-squares A b))) (is (magicl:= x (magicl:from-list '((2.2d0 1d0) (-2.5d0 -1d0)) '(2 2))))))) + +(deftest data-layout-invariant () + "Check that QR/QL/LU/etc. is invariant to row- or column-major order" + (loop repeat 50 + for row-major = (magicl:rand '(100 50) :layout :row-major) + for column-major = (magicl:@ row-major (magicl:eye '(50 50) :layout :column-major)) + for rms = (magicl:rand '(50 50) :layout :row-major) + for cms = (magicl:@ rms (magicl:eye '(50 50) :layout :column-major)) + do + ;; QR + (is (equalp + (multiple-value-list + (magicl:qr row-major)) + (multiple-value-list + (magicl:qr column-major)))) + ;; QL + (is (equalp + (multiple-value-list + (magicl:ql row-major)) + (multiple-value-list + (magicl:ql column-major)))) + ;; LU + (is (equalp + (multiple-value-list + (magicl:lu row-major)) + (multiple-value-list + (magicl:lu column-major)))) + ;; Solver + (flet ((solve (m) + (multiple-value-bind (lu ipiv) + (magicl:lu m) + (magicl:lu-solve lu ipiv (magicl:ones '(100 1)))))) + (is (equalp (solve rms) (solve cms)))) + ;; Inversion + (is (equalp (magicl:inv rms) (magicl:inv cms))) + ;; SVD + (is (equalp + (multiple-value-list + (magicl:svd row-major)) + (multiple-value-list + (magicl:svd column-major))))))