|
1 | 1 | using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
|
2 | 2 | using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
|
3 |
| -using MatrixAlgebraKit: svd_compact, svd_full |
| 3 | +using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank |
4 | 4 | using LinearAlgebra: LinearAlgebra
|
5 | 5 | using Random: Random
|
6 | 6 | using Test: @inferred, @testset, @test
|
|
83 | 83 | usv = svd_full(c)
|
84 | 84 | @test test_svd(c, usv; full=true)
|
85 | 85 | end
|
| 86 | + |
| 87 | +# svd_trunc! |
| 88 | +# ---------- |
| 89 | + |
| 90 | +@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params |
| 91 | + (m, n), T = first(test_params) |
| 92 | + a = BlockSparseArray{T}(undef, m, n) |
| 93 | + |
| 94 | + # test blockdiagonal |
| 95 | + for i in LinearAlgebra.diagind(blocks(a)) |
| 96 | + I = CartesianIndices(blocks(a))[i] |
| 97 | + a[Block(I.I...)] = rand(T, size(blocks(a)[i])) |
| 98 | + end |
| 99 | + |
| 100 | + minmn = min(size(a)...) |
| 101 | + r = max(1, minmn - 2) |
| 102 | + |
| 103 | + U1, S1, V1ᴴ = svd_trunc(a; trunc=truncrank(r)) |
| 104 | + U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc=truncrank(r)) |
| 105 | + @test size(U1) == size(U2) |
| 106 | + @test size(S1) == size(S2) |
| 107 | + @test size(V1ᴴ) == size(V2ᴴ) |
| 108 | + @test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ |
| 109 | + |
| 110 | + @test (U1' * U1 ≈ LinearAlgebra.I) |
| 111 | + @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I) |
| 112 | + |
| 113 | + # test permuted blockdiagonal |
| 114 | + perm = Random.randperm(length(m)) |
| 115 | + b = a[Block.(perm), Block.(1:length(n))] |
| 116 | + U1, S1, V1ᴴ = svd_trunc(b; trunc=truncrank(r)) |
| 117 | + U2, S2, V2ᴴ = svd_trunc(Matrix(b); trunc=truncrank(r)) |
| 118 | + @test size(U1) == size(U2) |
| 119 | + @test size(S1) == size(S2) |
| 120 | + @test size(V1ᴴ) == size(V2ᴴ) |
| 121 | + @test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ |
| 122 | + |
| 123 | + @test (U1' * U1 ≈ LinearAlgebra.I) |
| 124 | + @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I) |
| 125 | + |
| 126 | + # test permuted blockdiagonal with missing row/col |
| 127 | + I_removed = rand(eachblockstoredindex(b)) |
| 128 | + c = copy(b) |
| 129 | + delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed)))) |
| 130 | + U1, S1, V1ᴴ = svd_trunc(c; trunc=truncrank(r)) |
| 131 | + U2, S2, V2ᴴ = svd_trunc(Matrix(c); trunc=truncrank(r)) |
| 132 | + @test size(U1) == size(U2) |
| 133 | + @test size(S1) == size(S2) |
| 134 | + @test size(V1ᴴ) == size(V2ᴴ) |
| 135 | + @test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ |
| 136 | + |
| 137 | + @test (U1' * U1 ≈ LinearAlgebra.I) |
| 138 | + @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I) |
| 139 | +end |
0 commit comments