-
Notifications
You must be signed in to change notification settings - Fork 5
Exponential #94
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Exponential #94
Changes from 5 commits
eb913cb
a3dc04d
c4564ee
8dc3ecd
d9fb748
5095cdb
89dfa23
996ecb5
dc78eb0
f220035
c68afad
95ddb06
5d6f4f3
c8e811c
0229417
cbbf813
d08d545
720ada5
d738c22
be111ea
c760a47
d0d14e1
cf98bd4
1536eb4
349800e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,50 @@ | ||||||||||
| # Inputs | ||||||||||
| # ------ | ||||||||||
| function copy_input(::typeof(exponential), A::AbstractMatrix) | ||||||||||
| return copy!(similar(A, float(eltype(A))), A) | ||||||||||
| end | ||||||||||
|
|
||||||||||
| copy_input(::typeof(exponential), A::Diagonal) = copy(A) | ||||||||||
|
|
||||||||||
| function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) | ||||||||||
| m, n = size(A) | ||||||||||
| m == n || throw(DimensionMismatch("square input matrix expected")) | ||||||||||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| @assert expA isa AbstractMatrix | ||||||||||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| @check_size(expA, (m, m)) | ||||||||||
| return @check_scalar(expA, A) | ||||||||||
| end | ||||||||||
|
|
||||||||||
| # Outputs | ||||||||||
| # ------- | ||||||||||
| function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) | ||||||||||
| n = size(A, 1) # square check will happen later | ||||||||||
| expA = similar(A, (n, n)) | ||||||||||
lkdvos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| return expA | ||||||||||
| end | ||||||||||
|
|
||||||||||
| # Implementation | ||||||||||
| # -------------- | ||||||||||
| function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaLA) | ||||||||||
lkdvos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| copyto!(expA, LinearAlgebra.exp(A)) | ||||||||||
lkdvos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| return expA | ||||||||||
| end | ||||||||||
|
|
||||||||||
| function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEigh) | ||||||||||
lkdvos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| D, V = eigh_full(A, alg.eigh_alg) | ||||||||||
lkdvos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||
| copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V)) | ||||||||||
|
||||||||||
| copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V)) | |
| iV = inv(V) | |
| map!(exp, diagview(D)) | |
| mul!(expA, rmul!(V, D), iV) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has to be map!(exp, diagview(D), diagview(D)) instead of map!(exp, diagview(D)), but good suggestion otherwise. I have also added it for the ExponentialViaEig.
EDIT: the suggested change works only for Julia 1.12 onwards. That's why I will keep the version with
3 arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just diagview(D) .= exp.(diagview(D))?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is that more efficient than the current code? If not, I'd prefer to keep it that way, since it feels a bit more natural to me.
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| # Exponential functions | ||
| # -------------- | ||
| @functiondef exponential | ||
| # @algdef exponential! | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| # Algorithm selection | ||
| # ------------------- | ||
| default_exponential_algorithm(A; kwargs...) = default_exponential_algorithm(typeof(A); kwargs...) | ||
| function default_exponential_algorithm(T::Type; kwargs...) | ||
| return ExponentialViaLA(; kwargs...) | ||
| end | ||
|
|
||
| for f in (:exponential!,) | ||
| @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} | ||
| return default_exponential_algorithm(A; kwargs...) | ||
| end | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using TestExtras | ||
| using StableRNGs | ||
| using MatrixAlgebraKit: diagview | ||
| using LinearAlgebra | ||
|
|
||
| BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) | ||
| GenericFloats = (Float16, BigFloat, Complex{BigFloat}) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @testset "exp! for T = $T" for T in BLASFloats | ||
| rng = StableRNG(123) | ||
| m = 2 | ||
|
|
||
| A = randn(rng, T, m, m) | ||
| A = (A + A') / 2 | ||
| D, V = @constinferred eigh_full(A) | ||
| algs = (ExponentialViaLA(), ExponentialViaEig(LAPACK_Simple()), ExponentialViaEigh(LAPACK_QRIteration())) | ||
| expA_LA = @constinferred exp(A) | ||
| @testset "algorithm $alg" for alg in algs | ||
| expA = similar(A) | ||
|
|
||
| @constinferred exponential!(copy(A), expA) | ||
| expA2 = @constinferred exponential(A; alg = alg) | ||
| @test expA ≈ expA_LA | ||
| @test expA2 ≈ expA | ||
|
|
||
| Dexp, Vexp = @constinferred eigh_full(expA) | ||
| @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) | ||
| end | ||
| end | ||
|
|
||
| @testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| rng = StableRNG(123) | ||
| atol = sqrt(eps(real(T))) | ||
| m = 54 | ||
| Ad = randn(T, m) | ||
| A = Diagonal(Ad) | ||
|
|
||
| expA = similar(A) | ||
| @constinferred exponential!(copy(A), expA) | ||
| expA2 = @constinferred exponential(A; alg = DiagonalAlgorithm()) | ||
| @test expA2 ≈ expA | ||
|
|
||
| D, V = @constinferred eig_full(A) | ||
| Dexp, Vexp = @constinferred eig_full(expA) | ||
| @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using TestExtras | ||
| using StableRNGs | ||
| using MatrixAlgebraKit: diagview | ||
| using LinearAlgebra | ||
|
|
||
| GenericFloats = (BigFloat, Complex{BigFloat}) | ||
|
|
||
| @testset "exp! for T = $T" for T in GenericFloats | ||
| rng = StableRNG(123) | ||
| m = 2 | ||
|
|
||
| A = randn(rng, T, m, m) | ||
| A = (A + A') / 2 | ||
| D, V = @constinferred eigh_full(A) | ||
| algs = (ExponentialViaEigh(GLA_QRIteration()),) | ||
| @testset "algorithm $alg" for alg in algs | ||
| expA = similar(A) | ||
|
|
||
| @constinferred exponential!(copy(A), expA; alg) | ||
| expA2 = @constinferred exponential(A; alg) | ||
| @test expA2 ≈ expA | ||
|
|
||
| Dexp, Vexp = @constinferred eigh_full(expA) | ||
| @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) | ||
| end | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using TestExtras | ||
| using StableRNGs | ||
| using MatrixAlgebraKit: diagview | ||
| using LinearAlgebra | ||
|
|
||
| GenericFloats = (BigFloat, Complex{BigFloat}) | ||
|
|
||
| @testset "exp! for T = $T" for T in GenericFloats | ||
| rng = StableRNG(123) | ||
| m = 2 | ||
|
|
||
| A = randn(rng, T, m, m) | ||
| D, V = @constinferred eig_full(A) | ||
| algs = (ExponentialViaEig(GS_QRIteration()),) | ||
| expA_LA = @constinferred exponential(A) | ||
| @testset "algorithm $alg" for alg in algs | ||
| expA = similar(A) | ||
|
|
||
| @constinferred exponential!(copy(A), expA) | ||
| expA2 = @constinferred exponential(A; alg = alg) | ||
| @test expA ≈ expA_LA | ||
| @test expA2 ≈ expA | ||
|
|
||
| Dexp, Vexp = @constinferred eig_full(expA) | ||
| @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) | ||
| end | ||
| end |
Uh oh!
There was an error while loading. Please reload this page.