diff --git a/Project.toml b/Project.toml index 6736baa5..aa30c584 100644 --- a/Project.toml +++ b/Project.toml @@ -1,21 +1,24 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.2.2" +version = "0.2.3" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" [extensions] MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" +MatrixAlgebraKitFillArraysExt = "FillArrays" [compat] Aqua = "0.6, 0.7, 0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" +FillArrays = "1" JET = "0.9" LinearAlgebra = "1" SafeTestsets = "0.1" @@ -36,5 +39,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", - "ChainRulesTestUtils", "StableRNGs", "Zygote"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"] diff --git a/ext/MatrixAlgebraKitFillArraysExt.jl b/ext/MatrixAlgebraKitFillArraysExt.jl new file mode 100644 index 00000000..8d09c012 --- /dev/null +++ b/ext/MatrixAlgebraKitFillArraysExt.jl @@ -0,0 +1,401 @@ +module MatrixAlgebraKitFillArraysExt + +using LinearAlgebra +using MatrixAlgebraKit +using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy, + check_input, diagview, + findtruncated, select_algorithm, select_truncation +using FillArrays +using FillArrays: AbstractZerosMatrix, OnesVector, RectDiagonal, SquareEye + +function MatrixAlgebraKit.diagview(A::RectDiagonal{<:Any,<:OnesVector}) + return A.diag +end + +struct ZerosAlgorithm <: AbstractAlgorithm end + +for f in [:eig, :eigh, :lq, :polar, :qr, :svd] + ff = Symbol("default_", f, "_algorithm") + @eval begin + function MatrixAlgebraKit.$ff(::Type{<:AbstractZerosMatrix}; kwargs...) + return ZerosAlgorithm() + end + end +end + +for f in [:eig_full, + :eigh_full, + :eig_vals, + :eigh_vals, + :qr_compact, + :qr_full, + :left_polar, + :lq_compact, + :lq_full, + :right_polar, + :svd_compact, + :svd_full, + :svd_vals] + f! = Symbol(f, "!") + @eval begin + MatrixAlgebraKit.copy_input(::typeof($f), A::AbstractZerosMatrix) = A + function MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractZerosMatrix, + alg::ZerosAlgorithm) + return nothing + end + end +end + +for f in [:eig_full!, :eigh_full!] + @eval begin + function MatrixAlgebraKit.check_input(::typeof($f), A::AbstractZerosMatrix, F) + LinearAlgebra.checksquare(A) + return nothing + end + function MatrixAlgebraKit.$f(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm; + kwargs...) + check_input($f, A, F) + return (A, Eye(axes(A))) + end + end +end + +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + # TODO: Delete this when `select_algorithm` is generalized. + function MatrixAlgebraKit.select_algorithm(::typeof($f), ::Type{A}, alg; + trunc=nothing, + kwargs...) where {A<:Zeros} + alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) + end + # TODO: I think it would be better to dispatch on the algorithm here, + # rather than the output types. + function MatrixAlgebraKit.truncate!(::typeof($f), (D, V)::Tuple{Zeros,Eye}, + strategy::TruncationStrategy) + ind = findtruncated(diagview(D), strategy) + D′ = D[ind, ind] + V′ = Eye((axes(V, 1), only(axes(axes(V, 2)[ind])))) + return (D′, V′) + end + end +end + +for f in [:eig_vals!, :eigh_vals!] + @eval begin + function MatrixAlgebraKit.check_input(::typeof($f), A::AbstractZerosMatrix, F) + LinearAlgebra.checksquare(A) + return nothing + end + function MatrixAlgebraKit.$f(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm; + kwargs...) + check_input($f, A, F) + return diagview(A) + end + end +end + +function MatrixAlgebraKit.qr_compact!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm) + m, n = size(A) + ax = axes(A) + if m > n + r_ax = (ax[2], ax[2]) + return (Eye(ax), Zeros(r_ax)) + else + q_ax = (ax[1], ax[1]) + return (Eye(q_ax), Zeros(ax)) + end +end + +function MatrixAlgebraKit.qr_full!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm) + ax = axes(A) + q_ax = (ax[1], ax[1]) + return (Eye(q_ax), Zeros(ax)) +end + +function MatrixAlgebraKit.lq_compact!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm) + m, n = size(A) + ax = axes(A) + if m < n + l_ax = (ax[1], ax[1]) + return (Zeros(l_ax), Eye(ax)) + else + q_ax = (ax[2], ax[2]) + return (Zeros(ax), Eye(q_ax)) + end +end + +function MatrixAlgebraKit.lq_full!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm) + ax = axes(A) + q_ax = (ax[2], ax[2]) + return (Zeros(ax), Eye(q_ax)) +end + +function MatrixAlgebraKit.svd_compact!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm) + m, n = size(A) + ax = axes(A) + if m > n + s_ax = (ax[2], ax[2]) + return (Eye(ax), Zeros(s_ax), Eye(s_ax)) + else + s_ax = (ax[1], ax[1]) + return (Eye(s_ax), Zeros(s_ax), Eye(ax)) + end +end + +function MatrixAlgebraKit.svd_full!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm) + ax = axes(A) + return (Eye((ax[1], ax[1])), Zeros(ax), Eye((ax[2], ax[2]))) +end + +# TODO: Delete this when `select_algorithm` is generalized. +function MatrixAlgebraKit.select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; + trunc=nothing, + kwargs...) where {A<:Zeros} + alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) +end +# TODO: I think it would be better to dispatch on the algorithm here, +# rather than the output types. +function MatrixAlgebraKit.truncate!(::typeof(svd_trunc!), (U, S, V)::Tuple{Eye,Zeros,Eye}, + strategy::TruncationStrategy) + ind = findtruncated(diagview(S), strategy) + U′ = Eye((axes(U, 1), only(axes(axes(U, 2)[ind])))) + S′ = S[ind, ind] + V′ = Eye((only(axes(axes(V, 1)[ind])), axes(V, 2))) + return (U′, S′, V′) +end + +function MatrixAlgebraKit.svd_vals!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm) + return diagview(A) +end + +function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::AbstractZerosMatrix, F) + m, n = size(A) + m >= n || + throw(ArgumentError("input matrix needs at least as many rows as columns")) + return nothing +end +function MatrixAlgebraKit.left_polar!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm) + check_input(left_polar!, A, F) + U, S, Vᴴ = svd_compact(A) + return (Eye((axes(A, 1), axes(A, 2))), Vᴴ' * S * Vᴴ) +end + +function MatrixAlgebraKit.check_input(::typeof(right_polar!), A::AbstractZerosMatrix, F) + m, n = size(A) + n >= m || + throw(ArgumentError("input matrix needs at least as many columns as rows")) + return nothing +end +function MatrixAlgebraKit.right_polar!(A::AbstractZerosMatrix, F, alg::ZerosAlgorithm) + check_input(right_polar!, A, F) + U, S, Vᴴ = svd_compact(A) + return (U * S * U', Eye((axes(U, 1), axes(Vᴴ, 2)))) +end + +struct EyeAlgorithm <: AbstractAlgorithm end + +for f in [:eig, :eigh, :lq, :qr, :polar, :svd] + ff = Symbol("default_", f, "_algorithm") + @eval begin + function MatrixAlgebraKit.$ff(A::Type{<:Eye}; kwargs...) + return EyeAlgorithm() + end + end +end + +for f in [:eig_full, + :eigh_full, + :eig_vals, + :eigh_vals, + :qr_compact, + :qr_full, + :lq_compact, + :lq_full, + :left_polar, + :right_polar, + :svd_compact, + :svd_full, + :svd_vals] + f! = Symbol(f, "!") + @eval begin + MatrixAlgebraKit.copy_input(::typeof($f), A::Eye) = A + function MatrixAlgebraKit.initialize_output(::typeof($f!), A::Eye, + alg::EyeAlgorithm) + return nothing + end + end +end + +for f in [:eig_full!, :eigh_full!] + @eval begin + function MatrixAlgebraKit.check_input(::typeof($f), A::Eye, F) + LinearAlgebra.checksquare(A) + return nothing + end + function MatrixAlgebraKit.$f(A::Eye, F, alg::EyeAlgorithm; + kwargs...) + check_input($f, A, F) + return (A, A) + end + end +end + +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + # TODO: Delete this when `select_algorithm` is generalized. + function MatrixAlgebraKit.select_algorithm(::typeof($f), ::Type{A}, alg; + trunc=nothing, + kwargs...) where {A<:Eye} + alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) + end + # TODO: I think it would be better to dispatch on the algorithm here, + # rather than the output types. + function MatrixAlgebraKit.truncate!(::typeof($f), (D, V)::Tuple{Eye,Eye}, + strategy::TruncationStrategy) + ind = findtruncated(diagview(D), strategy) + D′ = Diagonal(diagview(D)[ind]) + U′ = Eye((axes(V, 1), only(axes(axes(V, 2)[ind])))) + return (D′, U′) + end + end +end + +for f in [:eig_vals!, :eigh_vals!] + @eval begin + function MatrixAlgebraKit.check_input(::typeof($f), A::Eye, F) + LinearAlgebra.checksquare(A) + return nothing + end + function MatrixAlgebraKit.$f(A::Eye, F, alg::EyeAlgorithm; + kwargs...) + check_input($f, A, F) + return diagview(A) + end + end +end + +function MatrixAlgebraKit.qr_compact!(A::Eye, F, alg::EyeAlgorithm) + m, n = size(A) + ax = axes(A) + if m > n + r_ax = (ax[2], ax[2]) + return (Eye(ax), Eye(r_ax)) + else + q_ax = (ax[1], ax[1]) + return (Eye(q_ax), Eye(ax)) + end +end +function MatrixAlgebraKit.qr_compact!(A::SquareEye, F, alg::EyeAlgorithm) + return (A, A) +end + +function MatrixAlgebraKit.qr_full!(A::Eye, F, alg::EyeAlgorithm) + ax = axes(A) + q_ax = (ax[1], ax[1]) + return (Eye(q_ax), A) +end +function MatrixAlgebraKit.qr_full!(A::SquareEye, F, alg::EyeAlgorithm) + return (A, A) +end + +function MatrixAlgebraKit.lq_compact!(A::Eye, F, alg::EyeAlgorithm) + m, n = size(A) + ax = axes(A) + if m < n + l_ax = (ax[1], ax[1]) + return (Eye(l_ax), Eye(ax)) + else + q_ax = (ax[2], ax[2]) + return (Eye(ax), Eye(q_ax)) + end +end +function MatrixAlgebraKit.lq_compact!(A::SquareEye, F, alg::EyeAlgorithm) + return (A, A) +end + +function MatrixAlgebraKit.lq_full!(A::Eye, F, alg::EyeAlgorithm) + ax = axes(A) + q_ax = (ax[2], ax[2]) + return (A, Eye(q_ax)) +end +function MatrixAlgebraKit.lq_full!(A::SquareEye, F, alg::EyeAlgorithm) + return (A, A) +end + +function MatrixAlgebraKit.svd_compact!(A::Eye, F, alg::EyeAlgorithm) + m, n = size(A) + ax = axes(A) + if m > n + s_ax = (ax[2], ax[2]) + return (Eye(ax), Eye(s_ax), Eye(s_ax)) + else + s_ax = (ax[1], ax[1]) + return (Eye(s_ax), Eye(s_ax), Eye(ax)) + end +end +function MatrixAlgebraKit.svd_compact!(A::SquareEye, F, alg::EyeAlgorithm) + return (A, A, A) +end + +function MatrixAlgebraKit.svd_full!(A::Eye, F, alg::EyeAlgorithm) + ax = axes(A) + return (Eye((ax[1],)), A, Eye((ax[2],))) +end +function MatrixAlgebraKit.svd_full!(A::SquareEye, F, alg::EyeAlgorithm) + return (A, A, A) +end + +# TODO: Delete this when `select_algorithm` is generalized. +function MatrixAlgebraKit.select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; + trunc=nothing, + kwargs...) where {A<:Eye} + alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) +end +# TODO: I think it would be better to dispatch on the algorithm here, +# rather than the output types. +function MatrixAlgebraKit.truncate!(::typeof(svd_trunc!), (U, S, V)::Tuple{Eye,Eye,Eye}, + strategy::TruncationStrategy) + ind = findtruncated(diagview(S), strategy) + U′ = Eye((axes(U, 1), only(axes(axes(U, 2)[ind])))) + S′ = Diagonal(diagview(S)[ind]) + V′ = Eye((only(axes(axes(V, 1)[ind])), axes(V, 2))) + return (U′, S′, V′) +end + +function MatrixAlgebraKit.svd_vals!(A::Eye, F, alg::EyeAlgorithm) + return diagview(A) +end + +function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::Eye, F) + m, n = size(A) + m >= n || + throw(ArgumentError("input matrix needs at least as many rows as columns")) + return nothing +end +function MatrixAlgebraKit.left_polar!(A::Eye, F, alg::EyeAlgorithm) + check_input(left_polar!, A, F) + return (Eye((axes(A, 1), axes(A, 2))), Eye((axes(A, 2), axes(A, 2)))) +end +function MatrixAlgebraKit.left_polar!(A::SquareEye, F, alg::EyeAlgorithm) + return (A, A) +end + +function MatrixAlgebraKit.check_input(::typeof(right_polar!), A::Eye, F) + m, n = size(A) + n >= m || + throw(ArgumentError("input matrix needs at least as many columns as rows")) + return nothing +end +function MatrixAlgebraKit.right_polar!(A::Eye, F, alg::EyeAlgorithm) + check_input(right_polar!, A, F) + return (Eye((axes(A, 1), axes(A, 1))), Eye((axes(A, 1), axes(A, 2)))) +end +function MatrixAlgebraKit.right_polar!(A::SquareEye, F, alg::EyeAlgorithm) + return (A, A) +end + +end diff --git a/test/fillarrays.jl b/test/fillarrays.jl new file mode 100644 index 00000000..ed278330 --- /dev/null +++ b/test/fillarrays.jl @@ -0,0 +1,362 @@ +using MatrixAlgebraKit +using LinearAlgebra +using Test +using TestExtras +using FillArrays +using FillArrays: SquareEye + +@testset "Zeros" begin + for f in [:eig_full, :eigh_full] + @eval begin + A = Zeros(3, 3) + D, V = @constinferred $f(A) + @test A * V == V * D + @test size(D) == size(A) + @test size(V) == size(A) + @test iszero(D) + @test D isa Zeros + @test V == I + @test V isa Eye + end + end + + for f in [:eig_trunc, :eigh_trunc] + @eval begin + A = Zeros(3, 3) + D, V = @constinferred $f(A; trunc=(; maxrank=2)) + @test A * V == V * D + @test size(D) == (2, 2) + @test size(V) == (3, 2) + @test D == Zeros(2, 2) + @test D isa Zeros + @test V == Eye(3, 2) + @test V isa Eye + end + end + + for f in [:eig_vals, :eigh_vals] + @eval begin + A = Zeros(3, 3) + D = @constinferred $f(A) + @test size(D) == (size(A, 1),) + @test iszero(D) + @test D isa Zeros + end + end + + for f in (qr_compact, left_orth) + A = Zeros(4, 3) + Q, R = @constinferred f(A) + @test Q * R == A + @test size(Q) == (4, 3) + @test size(R) == (3, 3) + @test Q == Matrix(I, (4, 3)) + @test Q isa Eye + @test iszero(R) + @test R isa Zeros + end + + A = Zeros(4, 3) + Q, R = @constinferred qr_full(A) + @test Q * R == A + @test size(Q) == (4, 4) + @test size(R) == (4, 3) + @test Q == I + @test Q isa Eye + @test iszero(R) + @test R isa Zeros + + A = Zeros(4, 3) + Q, R = @constinferred left_polar(A) + @test Q * R == A + @test size(Q) == (4, 3) + @test size(R) == (3, 3) + @test Q == Matrix(I, (4, 3)) + @test Q isa Eye + @test iszero(R) + @test R isa Zeros + + for f in (lq_compact, right_orth) + A = Zeros(3, 4) + L, Q = @constinferred f(A) + @test L * Q == A + @test L == Zeros(3, 3) + @test L isa Zeros + @test Q == Eye(3, 4) + @test Q isa Eye + end + + A = Zeros(3, 4) + L, Q = @constinferred lq_full(A) + @test L * Q == A + @test size(L) == (3, 4) + @test size(Q) == (4, 4) + @test iszero(L) + @test L isa Zeros + @test Q == I + @test Q isa Eye + + A = Zeros(3, 4) + L, Q = @constinferred lq_full(A) + @test L * Q == A + @test L === A + @test Q == Eye(4) + @test Q isa Eye + + A = Zeros(3, 4) + L, Q = @constinferred right_polar(A) + @test L * Q == A + @test size(L) == (3, 3) + @test size(Q) == (3, 4) + @test iszero(L) + @test L isa Zeros + @test Q == Matrix(I, (3, 4)) + @test Q isa Eye + + A = Zeros(3, 4) + U, S, V = @constinferred svd_compact(A) + @test U * S * V == A + @test size(U) == (3, 3) + @test size(S) == (3, 3) + @test size(V) == (3, 4) + @test iszero(S) + @test S isa Zeros + @test U == I + @test U isa Eye + @test V == Matrix(I, (3, 4)) + @test V isa Eye + + A = Zeros(3, 4) + U, S, V = @constinferred svd_full(A) + @test U * S * V == A + @test size(U) == (3, 3) + @test size(S) == (3, 4) + @test size(V) == (4, 4) + @test iszero(S) + @test S isa Zeros + @test U == I + @test U isa Eye + @test V == I + @test V isa Eye + + A = Zeros(3, 4) + U, S, V = @constinferred svd_trunc(A; trunc=(; maxrank=2)) + @test U * S * V == Eye(3, 2) * Zeros(2, 2) * Eye(2, 4) + @test size(U) == (3, 2) + @test size(S) == (2, 2) + @test size(V) == (2, 4) + @test S == Zeros(2, 2) + @test S isa Zeros + @test U == Eye(3, 2) + @test U isa Eye + @test V == Eye(2, 4) + @test V isa Eye + + A = Zeros(3, 4) + D = @constinferred svd_vals(A) + @test size(D) == (minimum(size(A)),) + @test iszero(D) + @test D isa Zeros +end + +@testset "Eye" begin + for f in [:eig_full, :eigh_full] + @eval begin + for A in (Eye(3), Eye(3, 3)) + local D, V = @constinferred $f(A) + @test A * V == V * D + @test size(D) == size(A) + @test size(V) == size(A) + @test V == I + @test typeof(D) === typeof(A) + @test V == I + @test typeof(V) === typeof(A) + end + end + end + + for f in [:eig_trunc, :eigh_trunc] + @eval begin + for A in (Eye(3), Eye(3, 3)) + local D, V = @constinferred $f(A; trunc=(; maxrank=2)) + @test A * V == V * D + @test size(D) == (2, 2) + @test size(V) == (3, 2) + @test D == Eye(2, 2) + @test D isa SquareEye + @test V == Eye(3, 2) + @test V isa Eye + end + end + end + + for f in [:eig_vals, :eigh_vals] + @eval begin + for A in (Eye(3), Eye(3, 3)) + local D = @constinferred $f(A) + @test size(D) == (size(A, 1),) + @test all(isone, D) + @test D isa Ones + end + end + end + + for f in (qr_compact, left_orth) + A = Eye(4, 3) + Q, R = @constinferred f(A) + @test Q * R == A + @test size(Q) == (4, 3) + @test size(R) == (3, 3) + @test Q == Matrix(I, (4, 3)) + @test Q isa Eye + @test R == I + @test R isa Eye + + A = Eye(3) + Q, R = @constinferred f(A) + @test Q * R == A + @test Q === A + @test R === A + end + + A = Eye(4, 3) + Q, R = @constinferred qr_full(A) + @test Q * R == A + @test size(Q) == (4, 4) + @test size(R) == (4, 3) + @test Q == I + @test Q isa Eye + @test R == Eye(4, 3) + @test R isa Eye + + A = Eye(3) + Q, R = @constinferred qr_full(A) + @test Q * R == A + @test Q === A + @test R === A + + A = Eye(4, 3) + Q, R = @constinferred left_polar(A) + @test Q * R == A + @test size(Q) == (4, 3) + @test size(R) == (3, 3) + @test Q == Matrix(I, (4, 3)) + @test Q isa Eye + @test R == I + @test R isa Eye + + A = Eye(3) + Q, R = @constinferred left_polar(A) + @test Q * R == A + @test Q === A + @test R === A + + for f in (lq_compact, right_orth) + A = Eye(3, 4) + L, Q = @constinferred lq_compact(A) + @test L * Q == A + @test size(L) == (3, 3) + @test size(Q) == (3, 4) + @test L == I + @test L isa Eye + @test Q == Matrix(I, (3, 4)) + @test Q isa Eye + + A = Eye(3) + L, Q = @constinferred lq_compact(A) + @test L * Q == A + @test L === A + @test Q === A + end + + A = Eye(3, 4) + L, Q = @constinferred lq_full(A) + @test L * Q == A + @test size(L) == (3, 4) + @test size(Q) == (4, 4) + @test L == Matrix(I, (3, 4)) + @test L isa Eye + @test Q == I + @test Q isa Eye + + A = Eye(3) + L, Q = @constinferred lq_full(A) + @test L * Q == A + @test L === A + @test Q === A + + A = Eye(3, 4) + L, Q = @constinferred right_polar(A) + @test L * Q == A + @test size(L) == (3, 3) + @test size(Q) == (3, 4) + @test L == I + @test L isa Eye + @test Q == Matrix(I, (3, 4)) + @test Q isa Eye + + A = Eye(3) + L, Q = @constinferred right_polar(A) + @test L * Q == A + @test L === A + @test Q === A + + A = Eye(3, 4) + U, S, V = @constinferred svd_compact(A) + @test U * S * V == A + @test size(U) == (3, 3) + @test size(S) == (3, 3) + @test size(V) == (3, 4) + @test S == I + @test S isa Eye + @test U == I + @test U isa Eye + @test V == Matrix(I, (3, 4)) + @test V isa Eye + + A = Eye(3) + U, S, V = @constinferred svd_compact(A) + @test U * S * V == A + @test U === A + @test S === A + @test V === A + + A = Eye(3, 4) + U, S, V = @constinferred svd_full(A) + @test U * S * V == A + @test size(U) == (3, 3) + @test size(S) == (3, 4) + @test size(V) == (4, 4) + @test S == Matrix(I, (3, 4)) + @test S isa Eye + @test U == I + @test U isa Eye + @test V == I + @test V isa Eye + + A = Eye(3) + U, S, V = @constinferred svd_full(A) + @test U * S * V == A + @test U === A + @test S === A + @test V === A + + A = Eye(3, 4) + U, S, V = @constinferred svd_trunc(A; trunc=(; maxrank=2)) + @test U * S * V == Eye(3, 2) * Eye(2, 2) * Eye(2, 4) + @test size(U) == (3, 2) + @test size(S) == (2, 2) + @test size(V) == (2, 4) + @test S == Eye(2, 2) + @test S isa Eye + @test U == Eye(3, 2) + @test U isa Eye + @test V == Eye(2, 4) + @test V isa Eye + + A = Eye(3, 4) + D = @constinferred svd_vals(A) + @test size(D) == (minimum(size(A)),) + @test all(isone, D) + @test D isa Ones +end