Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
eb913cb
implement exponential
sanderdemeyer Nov 6, 2025
a3dc04d
update on exponential
sanderdemeyer Nov 12, 2025
c4564ee
Merge branch 'QuantumKitHub:main' into exponential
sanderdemeyer Nov 12, 2025
8dc3ecd
remove comment
sanderdemeyer Nov 12, 2025
d9fb748
Merge branch 'exponential' of https://github.com/sanderdemeyer/Matrix…
sanderdemeyer Nov 12, 2025
5095cdb
comments
sanderdemeyer Nov 13, 2025
89dfa23
change name of decompositions.jl to matrixfunctions.jl
sanderdemeyer Nov 19, 2025
996ecb5
revert name change
sanderdemeyer Nov 19, 2025
dc78eb0
Merge branch 'main' into exponential
sanderdemeyer Nov 19, 2025
f220035
general comments
sanderdemeyer Nov 20, 2025
c68afad
bug fix
sanderdemeyer Nov 20, 2025
95ddb06
avoid allocation in diagonal case
sanderdemeyer Nov 20, 2025
5d6f4f3
Merge branch 'main' into exponential
sanderdemeyer Nov 26, 2025
c8e811c
include exponentiali(tau, A)
sanderdemeyer Nov 26, 2025
0229417
remove simple test case and make the test more general
sanderdemeyer Nov 26, 2025
cbbf813
fix formatting
sanderdemeyer Nov 26, 2025
d08d545
add docs
sanderdemeyer Dec 1, 2025
720ada5
remove a bunch of allocations and clean up
lkdvos Dec 3, 2025
d738c22
Merge branch 'main' into exponential
lkdvos Dec 3, 2025
be111ea
introduce `map_diagonal` to simplify and relax types
lkdvos Dec 3, 2025
c760a47
rework tests
lkdvos Dec 3, 2025
d0d14e1
revert wrong filename changes
lkdvos Dec 3, 2025
cf98bd4
avoid running non-GPU tests through buildkite
lkdvos Dec 3, 2025
1536eb4
correct wrong in-place assumptions
lkdvos Dec 3, 2025
349800e
fixes part II
lkdvos Dec 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ext/MatrixAlgebraKitGenericSchurExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration)
return GenericSchur.eigvals!(A)
end

function MatrixAlgebraKit.default_exponential_algorithm(E::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return ExponentialViaEig(GS_QRIteration(; kwargs...))
end

end
4 changes: 4 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ export left_polar, right_polar
export left_polar!, right_polar!
export left_orth, right_orth, left_null, right_null
export left_orth!, right_orth!, left_null!, right_null!
export exponential, exponential!

export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi
export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration
export LQViaTransposedQR
export PolarViaSVD, PolarNewton
export ExponentialViaLA, ExponentialViaEig, ExponentialViaEigh
export DiagonalAlgorithm
export NativeBlocked
export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar,
Expand Down Expand Up @@ -92,6 +94,7 @@ include("interface/gen_eig.jl")
include("interface/schur.jl")
include("interface/polar.jl")
include("interface/orthnull.jl")
include("interface/exponential.jl")

include("implementations/projections.jl")
include("implementations/truncation.jl")
Expand All @@ -104,6 +107,7 @@ include("implementations/gen_eig.jl")
include("implementations/schur.jl")
include("implementations/polar.jl")
include("implementations/orthnull.jl")
include("implementations/exponential.jl")

include("pullbacks/qr.jl")
include("pullbacks/lq.jl")
Expand Down
50 changes: 50 additions & 0 deletions src/implementations/exponential.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Inputs
# ------
function copy_input(::typeof(exponential), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
end

copy_input(::typeof(exponential), A::Diagonal) = copy(A)

function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
@assert expA isa AbstractMatrix
@check_size(expA, (m, m))
return @check_scalar(expA, A)
end

# Outputs
# -------
function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm)
n = size(A, 1) # square check will happen later
expA = similar(A, (n, n))
return expA
end

# Implementation
# --------------
function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaLA)
copyto!(expA, LinearAlgebra.exp(A))
return expA
end

function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEigh)
D, V = eigh_full(A, alg.eigh_alg)
copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reduced allocation strategy:

Suggested change
copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V))
iV = inv(V)
map!(exp, diagview(D))
mul!(expA, rmul!(V, D), iV)

Copy link
Contributor Author

@sanderdemeyer sanderdemeyer Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has to be map!(exp, diagview(D), diagview(D)) instead of map!(exp, diagview(D)), but good suggestion otherwise. I have also added it for the ExponentialViaEig.
EDIT: the suggested change works only for Julia 1.12 onwards. That's why I will keep the version with
3 arguments.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just diagview(D) .= exp.(diagview(D))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that more efficient than the current code? If not, I'd prefer to keep it that way, since it feels a bit more natural to me.

return expA
end

function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEig)
D, V = eig_full(A, alg.eig_alg)
copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V))
return expA
end

# Diagonal logic
# --------------
function exponential!(A::Diagonal, expA, alg::DiagonalAlgorithm)
check_input(exponential!, A, expA, alg)
copyto!(expA, Diagonal(LinearAlgebra.exp.(diagview(A))))
return expA
end
41 changes: 41 additions & 0 deletions src/interface/decompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,47 @@ Divide and Conquer algorithm.

const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi}

# ================================
# EXPONENTIAL ALGORITHMS
# ================================
"""
ExponentialViaLA()

Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`.
The `qr_alg` specifies which QR-decomposition implementation to use.
"""
@algdef ExponentialViaLA

"""
ExponentialViaEigh()

Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`.
The `qr_alg` specifies which QR-decomposition implementation to use.
"""
struct ExponentialViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm
eigh_alg::A
end
function Base.show(io::IO, alg::ExponentialViaEigh)
print(io, "ExponentialViaEigh(")
_show_alg(io, alg.eigh_alg)
return print(io, ")")
end

"""
ExponentialViaEig()

Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`.
The `qr_alg` specifies which QR-decomposition implementation to use.
"""
struct ExponentialViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm
eig_alg::A
end
function Base.show(io::IO, alg::ExponentialViaEig)
print(io, "ExponentialViaEigh(")
_show_alg(io, alg.eig_alg)
return print(io, ")")
end

# Various consts and unions
# -------------------------

Expand Down
17 changes: 17 additions & 0 deletions src/interface/exponential.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Exponential functions
# --------------
@functiondef exponential
# @algdef exponential!

# Algorithm selection
# -------------------
default_exponential_algorithm(A; kwargs...) = default_exponential_algorithm(typeof(A); kwargs...)
function default_exponential_algorithm(T::Type; kwargs...)
return ExponentialViaLA(; kwargs...)
end

for f in (:exponential!,)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_exponential_algorithm(A; kwargs...)
end
end
48 changes: 48 additions & 0 deletions test/exponential.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using MatrixAlgebraKit: diagview
using LinearAlgebra

BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
GenericFloats = (Float16, BigFloat, Complex{BigFloat})

@testset "exp! for T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 2

A = randn(rng, T, m, m)
A = (A + A') / 2
D, V = @constinferred eigh_full(A)
algs = (ExponentialViaLA(), ExponentialViaEig(LAPACK_Simple()), ExponentialViaEigh(LAPACK_QRIteration()))
expA_LA = @constinferred exp(A)
@testset "algorithm $alg" for alg in algs
expA = similar(A)

@constinferred exponential!(copy(A), expA)
expA2 = @constinferred exponential(A; alg = alg)
@test expA expA_LA
@test expA2 expA

Dexp, Vexp = @constinferred eigh_full(expA)
@test diagview(Dexp) LinearAlgebra.exp.(diagview(D))
end
end

@testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
rng = StableRNG(123)
atol = sqrt(eps(real(T)))
m = 54
Ad = randn(T, m)
A = Diagonal(Ad)

expA = similar(A)
@constinferred exponential!(copy(A), expA)
expA2 = @constinferred exponential(A; alg = DiagonalAlgorithm())
@test expA2 expA

D, V = @constinferred eig_full(A)
Dexp, Vexp = @constinferred eig_full(expA)
@test diagview(Dexp) LinearAlgebra.exp.(diagview(D))
end
28 changes: 28 additions & 0 deletions test/genericlinearalgebra/exponential.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using MatrixAlgebraKit: diagview
using LinearAlgebra

GenericFloats = (BigFloat, Complex{BigFloat})

@testset "exp! for T = $T" for T in GenericFloats
rng = StableRNG(123)
m = 2

A = randn(rng, T, m, m)
A = (A + A') / 2
D, V = @constinferred eigh_full(A)
algs = (ExponentialViaEigh(GLA_QRIteration()),)
@testset "algorithm $alg" for alg in algs
expA = similar(A)

@constinferred exponential!(copy(A), expA; alg)
expA2 = @constinferred exponential(A; alg)
@test expA2 ≈ expA

Dexp, Vexp = @constinferred eigh_full(expA)
@test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D))
end
end
29 changes: 29 additions & 0 deletions test/genericschur/exponential.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using MatrixAlgebraKit: diagview
using LinearAlgebra

GenericFloats = (BigFloat, Complex{BigFloat})

@testset "exp! for T = $T" for T in GenericFloats
rng = StableRNG(123)
m = 2

A = randn(rng, T, m, m)
D, V = @constinferred eig_full(A)
algs = (ExponentialViaEig(GS_QRIteration()),)
expA_LA = @constinferred exponential(A)
@testset "algorithm $alg" for alg in algs
expA = similar(A)

@constinferred exponential!(copy(A), expA)
expA2 = @constinferred exponential(A; alg = alg)
@test expA ≈ expA_LA
@test expA2 ≈ expA

Dexp, Vexp = @constinferred eig_full(expA)
@test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D))
end
end
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ if !is_buildkite
@safetestset "Image and Null Space" begin
include("orthnull.jl")
end
@safetestset "Exponential" begin
include("exponential.jl")
end
@safetestset "ChainRules" begin
include("chainrules.jl")
end
Expand Down Expand Up @@ -119,8 +122,14 @@ end
@safetestset "Hermitian Eigenvalue Decomposition" begin
include("genericlinearalgebra/eigh.jl")
end
@safetestset "Exponential" begin
include("genericlinearalgebra/exponential.jl")
end

using GenericSchur
@safetestset "General Eigenvalue Decomposition" begin
include("genericschur/eig.jl")
end
@safetestset "Exponential" begin
include("genericschur/exponential.jl")
end