diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index e4637875..14c63a4e 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -1,7 +1,7 @@ module MatrixAlgebraKitGenericLinearAlgebraExt using MatrixAlgebraKit -using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, default_fixgauge +using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, default_fixgauge using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! @@ -72,6 +72,11 @@ function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_Householde return _gla_householder_qr!(A, Q, R; alg.kwargs...) end +function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::GLA_HouseholderQR) + check_input(qr_null!, A, N, alg) + return _gla_householder_qr_null!(A, N; alg.kwargs...) +end + function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = false, blocksize = 1, pivoted = false) pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR.")) (blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR.")) @@ -109,6 +114,21 @@ function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = false, blocksi return Q, R end +function _gla_householder_qr_null!( + A::AbstractMatrix, N::AbstractMatrix; + positive = false, blocksize = 1, pivoted = false + ) + pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR.")) + (blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR.")) + m, n = size(A) + minmn = min(m, n) + fill!(N, zero(eltype(N))) + one!(view(N, (minmn + 1):m, 1:(m - minmn))) + Q̃, = qr!(A) + lmul!(Q̃, N) + return N +end + function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...)) end diff --git a/test/lq.jl b/test/lq.jl index 2f34e846..4da1fc7a 100644 --- a/test/lq.jl +++ b/test/lq.jl @@ -51,9 +51,9 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63) ) TestSuite.test_lq_algs(T, (m, n), LAPACK_LQ_ALGS) elseif T ∈ GenericFloats - TestSuite.test_lq(T, (m, n); test_null = false, test_pivoted = false, test_blocksize = false) + TestSuite.test_lq(T, (m, n); test_null = true, test_pivoted = false, test_blocksize = false) GLA_LQ_ALGS = (LQViaTransposedQR(GLA_HouseholderQR()),) - TestSuite.test_lq_algs(T, (m, n), GLA_LQ_ALGS; test_null = false) + TestSuite.test_lq_algs(T, (m, n), GLA_LQ_ALGS; test_null = true) end if m == n AT = Diagonal{T, Vector{T}} diff --git a/test/qr.jl b/test/qr.jl index 8b420fc6..3131349a 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -50,7 +50,7 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63) ) TestSuite.test_qr_algs(T, (m, n), LAPACK_QR_ALGS) elseif T ∈ GenericFloats - TestSuite.test_qr(T, (m, n); test_null = false, test_pivoted = false, test_blocksize = false) + TestSuite.test_qr(T, (m, n); test_null = true, test_pivoted = false, test_blocksize = false) GLA_QR_ALGS = (GLA_HouseholderQR(),) TestSuite.test_qr_algs(T, (m, n), GLA_QR_ALGS; test_null = false) end