From 958dd36488a3437b91734bca037e1ccc49a1206d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 02:55:12 -0500 Subject: [PATCH 1/3] Use Testsuite for AD tests --- Project.toml | 3 + .../MatrixAlgebraKitCUDAExt.jl | 21 +- .../MatrixAlgebraKitMooncakeExt.jl | 13 +- src/pullbacks/eig.jl | 9 +- src/pullbacks/lq.jl | 46 +- src/pullbacks/qr.jl | 46 +- src/pullbacks/svd.jl | 14 +- test/ad_utils.jl | 62 -- test/chainrules.jl | 592 +---------------- test/mooncake.jl | 604 +---------------- test/runtests.jl | 12 +- test/testsuite/TestSuite.jl | 4 + test/testsuite/ad_utils.jl | 328 ++++++++++ test/testsuite/chainrules.jl | 612 ++++++++++++++++++ test/testsuite/mooncake.jl | 480 ++++++++++++++ 15 files changed, 1557 insertions(+), 1289 deletions(-) delete mode 100644 test/ad_utils.jl create mode 100644 test/testsuite/ad_utils.jl create mode 100644 test/testsuite/chainrules.jl create mode 100644 test/testsuite/mooncake.jl diff --git a/Project.toml b/Project.toml index 0e9a26cc..350159d1 100644 --- a/Project.toml +++ b/Project.toml @@ -57,3 +57,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"] + +[sources] +CUDA = {url="https://github.com/JuliaGPU/CUDA.jl", rev="master"} diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 4d34dd9e..38f5b90f 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitCUDAExt using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! -using MatrixAlgebraKit: diagview, sign_safe +using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! @@ -183,4 +183,23 @@ function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix) return A, B end +MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4) +MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4) +function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...) + As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...)) + return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4) +end + +function LinearAlgebra.sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix) + #=m = size(A, 1) + n = size(B, 2) + I_n = fill!(similar(A, n), one(eltype(A))) + I_m = fill!(similar(B, m), one(eltype(B))) + L = kron(diagm(I_n), A) + kron(adjoint(B), diagm(I_m)) + x_vec = L \ -vec(C) + X = CuMatrix(reshape(x_vec, m, n))=# + hX = sylvester(collect(A), collect(B), collect(C)) + return CuArray(hX) +end + end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index f6feda8b..8aaf9e22 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt using Mooncake using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive using MatrixAlgebraKit -using MatrixAlgebraKit: inv_safe, diagview, copy_input +using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! @@ -26,6 +26,17 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu return CoDual(Ac, dAc), copy_input_pb end +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any} +function Mooncake.rrule!!(::CoDual{typeof(initialize_output)}, f_df::CoDual, A_dA::CoDual, alg_dalg::CoDual) + output = initialize_output(Mooncake.primal(f_df), Mooncake.primal(A_dA), Mooncake.primal(alg_dalg)) + doutput = Mooncake.zero_tangent(output) + function initialize_output_pb(::NoRData) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return CoDual(output, doutput), initialize_output_pb +end + + # two-argument in-place factorizations like LQ, QR, EIG for (f!, f, pb, adj) in ( (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 4a203f64..72d59c64 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -42,9 +42,12 @@ function eig_pullback!( mul!(view(VᴴΔV, :, indV), V', ΔV) mask = abs.(transpose(D) .- D) .< degeneracy_atol - Δgauge = norm(view(VᴴΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if isa(ΔA, Array) + # not GPU friendly... + Δgauge = norm(view(VᴴΔV, mask), Inf) + Δgauge ≤ gauge_atol || + @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index b30fe198..3dc311a6 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -36,22 +36,24 @@ function lq_pullback!( ΔA1 = view(ΔA, 1:p, :) ΔA2 = view(ΔA, (p + 1):m, :) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) - Δgauge = max(Δgauge, norm(ΔQ2, Inf)) - end - if !iszerotangent(ΔL) - ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn) - Δgauge = max(Δgauge, norm(ΔL22, Inf)) + if isa(ΔA, Array) # not GPU friendly + if minmn > p # case where A is rank-deficient + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + # in this case the number Householder reflections will + # change upon small variations, and all of the remaining + # columns of ΔQ should be zero for a gauge-invariant + # cost function + ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) + Δgauge = max(Δgauge, norm(ΔQ2, Inf)) + end + if !iszerotangent(ΔL) + ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn) + Δgauge = max(Δgauge, norm(ΔL22, Inf)) + end + Δgauge ≤ gauge_atol || + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" end - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" end ΔQ̃ = zero!(similar(Q, (p, n))) @@ -69,9 +71,11 @@ function lq_pullback!( # how the full Q2 will change, but this we omit for now, and we consider # Q2' * ΔQ2 as a gauge dependent quantity. ΔQ2Q1ᴴ = ΔQ2 * Q1' - Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if isa(ΔA, Array) # not GPU friendly + Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) + Δgauge ≤ gauge_atol || + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1) end end @@ -95,8 +99,10 @@ function lq_pullback!( Md = diagview(M) Md .= real.(Md) end - ldiv!(LowerTriangular(L11)', M) - ldiv!(LowerTriangular(L11)', ΔQ̃) + # not GPU friendly... + L11arr = typeof(L)(L11) + ldiv!(LowerTriangular(L11arr)', M) + ldiv!(LowerTriangular(L11arr)', ΔQ̃) ΔA1 = mul!(ΔA1, M, Q1, +1, 1) ΔA1 .+= ΔQ̃ return ΔA diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 888029be..dfff1d2a 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -37,22 +37,24 @@ function qr_pullback!( ΔA1 = view(ΔA, :, 1:p) ΔA2 = view(ΔA, :, (p + 1):n) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - Δgauge = max(Δgauge, norm(ΔQ2, Inf)) - end - if !iszerotangent(ΔR) - ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n) - Δgauge = max(Δgauge, norm(ΔR22, Inf)) + if isa(ΔA, Array) # not GPU friendly + if minmn > p # case where A is rank-deficient + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + # in this case the number Householder reflections will + # change upon small variations, and all of the remaining + # columns of ΔQ should be zero for a gauge-invariant + # cost function + ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) + Δgauge = max(Δgauge, norm(ΔQ2, Inf)) + end + if !iszerotangent(ΔR) + ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n) + Δgauge = max(Δgauge, norm(ΔR22, Inf)) + end + Δgauge ≤ gauge_atol || + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" end - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" end ΔQ̃ = zero!(similar(Q, (m, p))) @@ -69,9 +71,11 @@ function qr_pullback!( # how the full Q2 will change, but this we omit for now, and we consider # Q2' * ΔQ2 as a gauge dependent quantity. Q1dΔQ2 = Q1' * ΔQ2 - Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if isa(ΔA, Array) # not GPU friendly + Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) + Δgauge ≤ gauge_atol || + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1) end end @@ -95,8 +99,10 @@ function qr_pullback!( Md = diagview(M) Md .= real.(Md) end - rdiv!(M, UpperTriangular(R11)') - rdiv!(ΔQ̃, UpperTriangular(R11)') + # not GPU-friendly... + R11arr = typeof(R)(R11) + rdiv!(M, UpperTriangular(R11arr)') + rdiv!(ΔQ̃, UpperTriangular(R11arr)') ΔA1 = mul!(ΔA1, Q1, M, +1, 1) ΔA1 .+= ΔQ̃ return ΔA diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index a8f8b70c..415c7ec5 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -22,8 +22,8 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol """ function svd_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon(); - rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + rank_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])), + degeneracy_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) # Extract the SVD components @@ -33,7 +33,7 @@ function svd_pullback!( minmn = min(m, n) S = diagview(Smat) length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)")) - r = searchsortedlast(S, rank_atol; rev = true) # rank + r = findlast(s -> s ≥ rank_atol, S) # rank Ur = view(U, :, 1:r) Vᴴr = view(Vᴴ, 1:r, :) Sr = view(S, 1:r) @@ -71,9 +71,11 @@ function svd_pullback!( # check whether cotangents arise from gauge-invariance objective function mask = abs.(Sr' .- Sr) .< degeneracy_atol - Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if isa(ΔA, Array) # norm check not GPU friendly + Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) + Δgauge ≤ gauge_atol || + @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+ (aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol) diff --git a/test/ad_utils.jl b/test/ad_utils.jl deleted file mode 100644 index 7a7cf39a..00000000 --- a/test/ad_utils.jl +++ /dev/null @@ -1,62 +0,0 @@ -function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 - mul!(ΔU, U, gaugepart, -1, 1) - return ΔU, ΔVᴴ -end -function remove_eiggauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end -function remove_eighgauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = V' * ΔV - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end -function stabilize_eigvals!(D::AbstractVector) - absD = abs.(D) - p = invperm(sortperm(absD)) # rank of abs(D) - # account for exact degeneracies in absolute value when having complex conjugate pairs - for i in 1:(length(D) - 1) - if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially - p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones - end - end - n = maximum(p) - # rescale eigenvalues so that they lie on distinct radii in the complex plane - # that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n - radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n - for i in 1:length(D) - D[i] = sign(D[i]) * radii[p[i]] - end - return D -end -function make_eig_matrix(rng, T, n) - A = randn(rng, T, n, n) - D, V = eig_full(A) - stabilize_eigvals!(diagview(D)) - Ac = V * D * inv(V) - return (T <: Real) ? real(Ac) : Ac -end -function make_eigh_matrix(rng, T, n) - A = project_hermitian!(randn(rng, T, n, n)) - D, V = eigh_full(A) - stabilize_eigvals!(diagview(D)) - return project_hermitian!(V * D * V') -end - -precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T))) diff --git a/test/chainrules.jl b/test/chainrules.jl index a8b2fd3b..c0ab618a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,590 +1,18 @@ using MatrixAlgebraKit using Test -using TestExtras -using StableRNGs -using ChainRulesCore, ChainRulesTestUtils, Zygote -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! -include("ad_utils.jl") +#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI -for f in - ( - :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, - :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, - :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, - :left_polar, :right_polar, - ) - copy_f = Symbol(:copy_, f) - f! = Symbol(f, '!') - _hermitian = startswith(string(f), "eigh") - @eval begin - function $copy_f(input, alg) - if $_hermitian - input = (input + input') / 2 - end - return $f(input, alg) - end - function ChainRulesCore.rrule(::typeof($copy_f), input, alg) - output = MatrixAlgebraKit.initialize_output($f!, input, alg) - if $_hermitian - input = (input + input') / 2 - else - input = copy(input) - end - output, pb = ChainRulesCore.rrule($f!, input, output, alg) - return output, x -> (NoTangent(), pb(x)[2], NoTangent()) - end - end -end - -@timedtestset "QR AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - # qr_compact - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - alg = LAPACK_HouseholderQR(; positive = true) - Q, R = copy_qr_compact(A, alg) - ΔQ = randn(rng, T, m, minmn) - ΔR = randn(rng, T, minmn, n) - ΔR2 = UpperTriangular(randn(rng, T, minmn, minmn)) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - test_rrule( - copy_qr_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - test_rrule( - copy_qr_null, A, alg ⊢ NoTangent(); - output_tangent = ΔN, atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ qr_compact, A; - fkwargs = (; positive = true), output_tangent = ΔQ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ qr_compact, A; - fkwargs = (; positive = true), output_tangent = ΔR, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, qr_null, A; - fkwargs = (; positive = true), output_tangent = ΔN, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - # qr_full - Q, R = copy_qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - test_rrule( - copy_qr_full, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_full, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - if m > n - _, null_pb = Zygote.pullback(qr_null, A, alg) - @test_logs (:warn,) null_pb(randn(rng, T, m, max(0, m - minmn))) - _, full_pb = Zygote.pullback(qr_full, A, alg) - @test_logs (:warn,) full_pb((randn(rng, T, m, m), randn(rng, T, m, n))) - end - # rank-deficient A - r = minmn - 5 - A = randn(rng, T, m, r) * randn(rng, T, r, n) - Q, R = qr_compact(A, alg) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) - view(ΔR, (r + 1):minmn, :) .= 0 - test_rrule( - copy_qr_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "LQ AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - # lq_compact - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - alg = LAPACK_HouseholderLQ(; positive = true) - L, Q = copy_lq_compact(A, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - test_rrule( - copy_lq_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - test_rrule( - copy_lq_null, A, alg ⊢ NoTangent(); - output_tangent = ΔNᴴ, atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ lq_compact, A; - fkwargs = (; positive = true), output_tangent = ΔL, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ lq_compact, A; - fkwargs = (; positive = true), output_tangent = ΔQ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, lq_null, A; - fkwargs = (; positive = true), output_tangent = ΔNᴴ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - # lq_full - L, Q = copy_lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - test_rrule( - copy_lq_full, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_full, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - if m < n - Nᴴ, null_pb = Zygote.pullback(lq_null, A, alg) - @test_logs (:warn,) null_pb(randn(rng, T, max(0, n - minmn), n)) - _, full_pb = Zygote.pullback(lq_full, A, alg) - @test_logs (:warn,) full_pb((randn(rng, T, m, n), randn(rng, T, n, n))) - end - # rank-deficient A - r = minmn - 5 - A = randn(rng, T, m, r) * randn(rng, T, r, n) - L, Q = lq_compact(A, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - test_rrule( - copy_lq_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "EIG AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eig_matrix(rng, T, m) - D, V = eig_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - for alg in (LAPACK_Simple(), LAPACK_Expert()) - test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol - ) - test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol - ) - test_rrule( - copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol - ) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, eig_full, A; - output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eig_full, A; - output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ eig_full, A; - output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ eig_full, A; - output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eig_vals, A; - output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eigh_matrix(rng, T, m) - D, V = eigh_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - for alg in ( - LAPACK_QRIteration(), LAPACK_DivideAndConquer(), LAPACK_Bisection(), - LAPACK_MultipleRelativelyRobustRepresentations(), - ) - # copy_eigh_full includes a projector onto the Hermitian part of the matrix - test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol - ) - test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol - ) - test_rrule( - copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol - ) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - # eigh_full does not include a projector onto the Hermitian part of the matrix - test_rrule( - config, eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eigh_vals ∘ Matrix ∘ Hermitian, A; - output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) - for r in 1:4:m - trunc = truncrank(r; by = real) - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) - test_rrule( - config, eigh_trunc2, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end - trunc = trunctol(; rtol = 1 / 2) - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) - test_rrule( - config, eigh_trunc2, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@timedtestset "SVD AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - for alg in (LAPACK_QRIteration(), LAPACK_DivideAndConquer()) - test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔU, ΔS, ΔVᴴ), atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_vals, A, alg ⊢ NoTangent(); - output_tangent = diagview(ΔS), atol, rtol - ) - for r in 1:4:minmn - truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) - dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) - dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, svd_compact, A; - output_tangent = (ΔU, ΔS, ΔVᴴ), atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_compact, A; - output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_vals, A; - output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - for r in 1:4:minmn - trunc = truncrank(r) - ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) - test_rrule( - config, svd_trunc, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_trunc_no_error, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end - trunc = trunctol(; atol = S[1, 1] / 2) - ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) - test_rrule( - config, svd_trunc, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_trunc_no_error, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "Polar AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) - m >= n && - test_rrule(copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) - m <= n && - test_rrule(copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - m >= n && test_rrule( - config, left_polar, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m <= n && test_rrule( - config, right_polar, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "Orth and null with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, left_orth, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, left_orth, A; - fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m >= n && - test_rrule( - config, left_orth, A; - fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - test_rrule( - config, left_null, A; - fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - test_rrule( - config, right_orth, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, right_orth, A; fkwargs = (; alg = :lq), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m <= n && - test_rrule( - config, right_orth, A; fkwargs = (; alg = :polar), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - test_rrule( - config, right_null, A; - fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) +m = 19 +for T in BLASFloats, n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite # doesn't work on GPU + TestSuite.test_chainrules(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/mooncake.jl b/test/mooncake.jl index 760102b1..feae6090 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -1,597 +1,25 @@ using MatrixAlgebraKit using Test -using TestExtras -using StableRNGs -using Mooncake, Mooncake.TestUtils -using Mooncake: rrule!! -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! +using CUDA, AMDGPU -include("ad_utils.jl") +#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI -make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) -make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA -make_mooncake_tangent(ΔA::Vector{<:Real}) = ΔA -make_mooncake_tangent(ΔA::Matrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔA::Vector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite -make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...) +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -make_mooncake_fdata(x) = make_mooncake_tangent(x) -make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) - -ETs = (Float32, ComplexF64) - -# no `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_copy = make_mooncake_tangent(copy(ΔA)) - A_copy = copy(A) - dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) - copy_pb!!(rdata) - return dA_copy -end - -# `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_copy = make_mooncake_tangent(copy(ΔA)) - A_copy = copy(A) - dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) - copy_pb!!(rdata) - return dA_copy -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_inplace = make_mooncake_tangent(copy(ΔA)) - A_inplace = copy(A) - dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) - end - inplace_pb!!(rdata) - return dA_inplace -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = make_mooncake_tangent(copy(ΔA)) - A_inplace = copy(A) - dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) - end - inplace_pb!!(rdata) - return dA_inplace -end - -""" - test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - -Compare the result of running the *in-place, mutating* function `f!`'s reverse rule -with the result of running its *non-mutating* partner function `f`'s reverse rule. -We must compare directly because many of the mutating functions modify `A` as a -scratch workspace, making testing `f!` against finite differences infeasible. - -The arguments to this function are: - - `f!` the mutating, in-place version of the function (accepts `args` for the function result) - - `f` the non-mutating version of the function (does not accept `args` for the function result) - - `A` the input matrix to factorize - - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) - - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input - - `alg` optional algorithm keyword argument - - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) -""" -function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) - sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - rrule = Mooncake.build_rrule(rvs_interp, sig) - ΔA = randn(rng, eltype(A), size(A)) - - dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) - - dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] - dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] - @test dA_inplace_ ≈ dA_copy_ - return -end - -@timedtestset "QR AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_HouseholderQR(), - LAPACK_HouseholderQR(; positive = true), - ) - @testset "qr_compact" begin - QR = qr_compact(A, alg) - Q = randn(rng, T, m, minmn) - R = randn(rng, T, minmn, n) - Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, A, (Q, R), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) - end - @testset "qr_null" begin - Q, R = qr_compact(A, alg) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - N = qr_null(A, alg) - dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dN, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) - end - @testset "qr_full" begin - Q, R = qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - dQ = make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) - end - @testset "qr_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - Q, R = qr_compact(Ard, alg) - QR = (Q, R) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) - view(ΔR, (r + 1):minmn, :) .= 0 - dQ = make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) - end - end - end -end - -@timedtestset "LQ AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_HouseholderLQ(), - LAPACK_HouseholderLQ(; positive = true), - ) - @testset "lq_compact" begin - L, Q = lq_compact(A, alg) - Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) - end - @testset "lq_null" begin - L, Q = lq_compact(A, alg) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - Nᴴ = randn(rng, T, max(0, n - minmn), n) - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) - end - @testset "lq_full" begin - L, Q = lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) - end - @testset "lq_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - L, Q = lq_compact(Ard, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) - end - end - end -end - -@timedtestset "EIG AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eig_matrix(rng, T, m) - DV = eig_full(A) - D, V = DV - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - - dD = make_mooncake_tangent(ΔD2) - dV = make_mooncake_tangent(ΔV) - dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - # compute the dA corresponding to the above dD, dV - @testset for alg in ( - LAPACK_Simple(), - #LAPACK_Expert(), # expensive on CI - ) - @testset "eig_full" begin - Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) - end - @testset "eig_vals" begin - Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) - end - @testset "eig_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - end -end - -function copy_eigh_full(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_full(A, alg; kwargs...) -end - -function copy_eigh_full!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_full!(A, DV, alg; kwargs...) -end - -function copy_eigh_vals(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals(A, alg; kwargs...) -end - -function copy_eigh_vals!(A, D, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D, alg; kwargs...) -end - -function copy_eigh_trunc(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc(A, alg; kwargs...) -end - -function copy_eigh_trunc!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc!(A, DV, alg; kwargs...) -end - -function copy_eigh_trunc_no_error(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc_no_error(A, alg; kwargs...) -end - -function copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc_no_error!(A, DV, alg; kwargs...) -end - -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) - -@timedtestset "EIGH AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eigh_matrix(rng, T, m) - D, V = eigh_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - dD = make_mooncake_tangent(ΔD2) - dV = make_mooncake_tangent(ΔV) - dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - @testset for alg in ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), - #LAPACK_Bisection(), - #LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI - ) - @testset "eigh_full" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) - end - @testset "eigh_vals" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) - end - @testset "eigh_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end +m = 19 +for T in BLASFloats, n in (17, m, 23) + TestSuite.seed_rng!(123) + if CUDA.functional() + TestSuite.test_mooncake(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end -end - -@timedtestset "SVD AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), # expensive on CI - ) - @testset "svd_compact" begin - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - dS = make_mooncake_tangent(ΔS2) - dU = make_mooncake_tangent(ΔU) - dVᴴ = make_mooncake_tangent(ΔVᴴ) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), alg) - end - @testset "svd_full" begin - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - ΔUfull = zeros(T, m, m) - ΔSfull = zeros(real(T), m, n) - ΔVᴴfull = zeros(T, n, n) - U, S, Vᴴ = svd_full(A) - view(ΔUfull, :, 1:minmn) .= ΔU - view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ - diagview(ΔSfull)[1:minmn] .= diagview(ΔS2) - dS = make_mooncake_tangent(ΔSfull) - dU = make_mooncake_tangent(ΔUfull) - dVᴴ = make_mooncake_tangent(ΔVᴴfull) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔUfull, ΔSfull, ΔVᴴfull)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) - end - @testset "svd_vals" begin - Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - S = svd_vals(A, alg) - test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, randn(rng, real(T), minmn), alg) - end - @testset "svd_trunc" begin - @testset for r in 1:4:minmn - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) - dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) - dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) - end - @testset "trunctol" begin - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) - dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) - dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) - end - end - end - end -end - -@timedtestset "Polar AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - @testset for alg in PolarViaSVD.( - ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), # expensive on CI - ) - ) - if m >= n - WP = left_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, left_polar!, left_polar, A, WP, (randn(rng, T, m, n), randn(rng, T, n, n)), alg) - elseif m <= n - PWᴴ = right_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, (randn(rng, T, m, m), randn(rng, T, m, n)), alg) - end - end - end -end - -left_orth_qr(X) = left_orth(X; alg = :qr) -left_orth_polar(X) = left_orth(X; alg = :polar) -left_null_qr(X) = left_null(X; alg = :qr) -right_orth_lq(X) = right_orth(X; alg = :lq) -right_orth_polar(X) = right_orth(X; alg = :polar) -right_null_lq(X) = right_null(X; alg = :lq) - -MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) - -@timedtestset "Orth and null with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - VC = left_orth(A) - CVᴴ = right_orth(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, left_orth!, left_orth, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - end - - N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) - test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - - if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - end - - Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) - test_pullbacks_match(rng, ((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) + #=if AMDGPU.functional() + TestSuite.test_mooncake(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end=# # not yet supported + if !is_buildkite + TestSuite.test_mooncake(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/runtests.jl b/test/runtests.jl index 7325d410..ea68929e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,12 +25,6 @@ if !is_buildkite @safetestset "Image and Null Space" begin include("orthnull.jl") end - @safetestset "Mooncake" begin - include("mooncake.jl") - end - @safetestset "ChainRules" begin - include("chainrules.jl") - end @safetestset "MatrixAlgebraKit.jl" begin @safetestset "Code quality (Aqua.jl)" begin using MatrixAlgebraKit @@ -67,6 +61,12 @@ end @safetestset "Schur Decomposition" begin include("schur.jl") end +@safetestset "Mooncake" begin + include("mooncake.jl") +end +@safetestset "ChainRules" begin + include("chainrules.jl") +end using CUDA if CUDA.functional() diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index deee4e12..873bd65e 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -69,10 +69,14 @@ is_positive(alg::MatrixAlgebraKit.ROCSOLVER_HouseholderQR) = alg.positive is_positive(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_positive(alg.qr_alg) is_pivoted(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_pivoted(alg.qr_alg) +include("ad_utils.jl") + include("qr.jl") include("lq.jl") include("polar.jl") include("projections.jl") include("schur.jl") +include("mooncake.jl") +include("chainrules.jl") end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl new file mode 100644 index 00000000..4d4108b9 --- /dev/null +++ b/test/testsuite/ad_utils.jl @@ -0,0 +1,328 @@ +function remove_svdgauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + ) + gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end +function remove_eiggauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end + +function stabilize_eigvals!(D::AbstractVector) + absD = collect(abs.(D)) + p = invperm(sortperm(collect(absD))) # rank of abs(D) + # account for exact degeneracies in absolute value when having complex conjugate pairs + for i in 1:(length(D) - 1) + if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially + p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones + end + end + n = maximum(p) + # rescale eigenvalues so that they lie on distinct radii in the complex plane + # that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n + radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n + hD = sign.(collect(D)) .* radii[p] + copyto!(D, hD) + return D +end +function make_eig_matrix(T, sz) + A = instantiate_matrix(T, sz) + D, V = eig_full(A) + stabilize_eigvals!(diagview(D)) + Ac = V * D * inv(V) + return (T <: Real) ? real(Ac) : Ac +end +function make_eigh_matrix(T, sz) + A = project_hermitian!(instantiate_matrix(T, sz)) + D, V = eigh_full(A) + stabilize_eigvals!(diagview(D)) + return project_hermitian!(V * D * V') +end + +function ad_qr_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + QR = qr_compact(A) + T = eltype(A) + ΔQ = randn!(similar(A, T, m, minmn)) + ΔR = randn!(similar(A, T, minmn, n)) + return QR, (ΔQ, ΔR) +end + +function ad_qr_null_setup(A) + m, n = size(A) + minmn = min(m, n) + Q, R = qr_compact(A) + T = eltype(A) + ΔN = Q * randn!(similar(A, T, minmn, max(0, m - minmn))) + N = qr_null(A) + return N, ΔN +end + +function ad_qr_full_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + Q, R = qr_full(A) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn!(similar(A, T, m, m)) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn!(similar(A, T, m, n)) + return (Q, R), (ΔQ, ΔR) +end + +function ad_qr_rd_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) + Q, R = qr_compact(Ard) + QR = (Q, R) + ΔQ = randn!(similar(A, T, m, minmn)) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn!(similar(A, T, minmn, n)) + view(ΔR, (r + 1):minmn, :) .= 0 + return (Q, R), (ΔQ, ΔR) +end + +function ad_lq_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + LQ = lq_compact(A) + T = eltype(A) + ΔL = randn!(similar(A, T, m, minmn)) + ΔQ = randn!(similar(A, T, minmn, n)) + return LQ, (ΔL, ΔQ) +end + +function ad_lq_null_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + L, Q = lq_compact(A) + ΔNᴴ = randn!(similar(A, T, max(0, n - minmn), minmn)) * Q + Nᴴ = randn!(similar(A, T, max(0, n - minmn), n)) + return Nᴴ, ΔNᴴ +end + +function ad_lq_full_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + L, Q = lq_full(A) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn!(similar(A, T, n, n)) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = randn!(similar(A, T, m, n)) + return (L, Q), (ΔL, ΔQ) +end + +function ad_lq_rd_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) + L, Q = lq_compact(Ard) + ΔL = randn!(similar(A, T, m, minmn)) + ΔQ = randn!(similar(A, T, minmn, n)) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + return (L, Q), (ΔL, ΔQ) +end + +function ad_eig_full_setup(A) + m, n = size(A) + T = eltype(A) + DV = eig_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn!(similar(A, complex(T), m, m)) + ΔV = remove_eiggauge_dependence!(ΔV, D, V) + ΔD = randn!(similar(A, complex(T), m, m)) + ΔD2 = Diagonal(randn!(similar(A, complex(T), m))) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eig_vals_setup(A) + m, n = size(A) + T = eltype(A) + D = eig_vals(A) + ΔD = randn!(similar(A, complex(T), m)) + return D, ΔD +end + +function ad_eig_trunc_setup(A, truncalg) + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + Dtrunc = Diagonal(diagview(DV[1])[ind]) + Vtrunc = DV[2][:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2V[1])[ind]) + ΔVtrunc = ΔDV[2][:, ind] + return DV, (Dtrunc, Vtrunc), ΔD2V, (ΔDtrunc, ΔVtrunc) +end + +function ad_eigh_full_setup(A) + m, n = size(A) + T = eltype(A) + DV = eigh_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn!(similar(A, T, m, m)) + ΔV = remove_eighgauge_dependence!(ΔV, D, V) + ΔD = randn!(similar(A, real(T), m, m)) + ΔD2 = Diagonal(randn!(similar(A, real(T), m))) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eigh_vals_setup(A) + m, n = size(A) + T = eltype(A) + D = eigh_vals(A) + ΔD = randn!(similar(A, real(T), m)) + return D, ΔD +end + +function ad_eigh_trunc_setup(A, truncalg) + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + Dtrunc = Diagonal(diagview(DV[1])[ind]) + Vtrunc = DV[2][:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2V[1])[ind]) + ΔVtrunc = ΔDV[2][:, ind] + return DV, (Dtrunc, Vtrunc), ΔD2V, (ΔDtrunc, ΔVtrunc) +end + +function ad_svd_compact_setup(A) + m, n = size(A) + T = eltype(A) + minmn = min(m, n) + ΔU = randn!(similar(A, T, m, minmn)) + ΔS = randn!(similar(A, real(T), minmn, minmn)) + ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) + ΔVᴴ = randn!(similar(A, T, minmn, n)) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ) +end + +function ad_svd_full_setup(A) + m, n = size(A) + T = eltype(A) + minmn = min(m, n) + ΔU = randn!(similar(A, T, m, minmn)) + ΔS = randn!(similar(A, real(T), minmn, minmn)) + ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) + ΔVᴴ = randn!(similar(A, T, minmn, n)) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ΔUfull = similar(A, T, m, m) + ΔUfull .= zero(T) + ΔSfull = similar(A, real(T), m, n) + ΔSfull .= zero(real(T)) + ΔVᴴfull = similar(A, T, n, n) + ΔVᴴfull .= zero(T) + U, S, Vᴴ = svd_full(A) + view(ΔUfull, :, 1:minmn) .= ΔU + view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ + diagview(ΔSfull)[1:minmn] .= diagview(ΔS2) + return (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull) +end + +function ad_svd_vals_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + S = svd_vals(A) + ΔS = randn!(similar(A, real(T), minmn)) + return S, ΔS +end + +function ad_svd_trunc_setup(A, truncalg) + USVᴴ, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(USVᴴ[2]), truncalg.trunc) + Strunc = Diagonal(diagview(USVᴴ[2])[ind]) + Utrunc = USVᴴ[1][:, ind] + Vᴴtrunc = USVᴴ[3][ind, :] + ΔStrunc = Diagonal(diagview(ΔUS2Vᴴ[2])[ind]) + ΔUtrunc = ΔUSVᴴ[1][:, ind] + ΔVᴴtrunc = ΔUSVᴴ[3][ind, :] + return USVᴴ, ΔUS2Vᴴ, (ΔUtrunc, ΔStrunc, ΔVᴴtrunc) +end + +function ad_left_polar_setup(A) + m, n = size(A) + T = eltype(A) + WP = left_polar(A) + ΔWP = (randn!(similar(A, T, m, n)), randn!(similar(A, T, n, n))) + return WP, ΔWP +end + +function ad_right_polar_setup(A) + m, n = size(A) + T = eltype(A) + PWᴴ = right_polar(A) + ΔPWᴴ = (randn!(similar(A, T, m, m)), randn!(similar(A, T, m, n))) + return PWᴴ, ΔPWᴴ +end + +function ad_left_orth_setup(A) + m, n = size(A) + T = eltype(A) + VC = left_orth(A) + ΔVC = (randn!(similar(A, T, size(VC[1])...)), randn!(similar(A, T, size(VC[2])...))) + return VC, ΔVC +end + +function ad_left_null_setup(A) + m, n = size(A) + T = eltype(A) + N = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) + ΔN = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) + return N, ΔN +end + +function ad_right_orth_setup(A) + m, n = size(A) + T = eltype(A) + CVᴴ = right_orth(A) + ΔCVᴴ = (randn!(similar(A, T, size(CVᴴ[1])...)), randn!(similar(A, T, size(CVᴴ[2])...))) + return CVᴴ, ΔCVᴴ +end + +function ad_right_null_setup(A) + m, n = size(A) + T = eltype(A) + Nᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] + ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] + return Nᴴ, ΔNᴴ +end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl new file mode 100644 index 00000000..625ff6a3 --- /dev/null +++ b/test/testsuite/chainrules.jl @@ -0,0 +1,612 @@ +using MatrixAlgebraKit +using ChainRulesCore, ChainRulesTestUtils, Zygote +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +for f in + ( + :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, + :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, + :eig_trunc_no_error, :eigh_trunc_no_error, + :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, + :left_polar, :right_polar, + ) + copy_f = Symbol(:cr_copy_, f) + f! = Symbol(f, '!') + _hermitian = startswith(string(f), "eigh") + @eval begin + function $copy_f(input, alg) + if $_hermitian + input = (input + input') / 2 + end + return $f(input, alg) + end + function ChainRulesCore.rrule(::typeof($copy_f), input, alg) + output = MatrixAlgebraKit.initialize_output($f!, input, alg) + if $_hermitian + input = (input + input') / 2 + else + input = copy(input) + end + output, pb = ChainRulesCore.rrule($f!, input, output, alg) + return output, x -> (NoTangent(), pb(x)[2], NoTangent()) + end + end +end + +function test_chainrules(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Chainrules AD $summary_str" begin + test_chainrules_qr(T, sz; kwargs...) + test_chainrules_lq(T, sz; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_chainrules_eig(T, sz; kwargs...) + test_chainrules_eigh(T, sz; kwargs...) + end + test_chainrules_svd(T, sz; kwargs...) + test_chainrules_polar(T, sz; kwargs...) + test_chainrules_orthnull(T, sz; kwargs...) + end +end + +function test_chainrules_qr( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "QR ChainRules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_qr_algorithm(A) + @testset "qr_compact" begin + QR, ΔQR = ad_qr_compact_setup(A) + ΔQ, ΔR = ΔQR + test_rrule( + cr_copy_qr_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "qr_null" begin + N, ΔN = ad_qr_null_setup(A) + test_rrule( + cr_copy_qr_null, A, alg ⊢ NoTangent(); + output_tangent = ΔN, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_null, A; + fkwargs = (; positive = true), output_tangent = ΔN, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m, n = size(A) + end + @testset "qr_full" begin + QR, ΔQR = ad_qr_full_setup(A) + test_rrule( + cr_copy_qr_full, A, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_full, A; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m, n = size(A) + end + @testset "qr_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + QR, ΔQR = ad_qr_rd_compact_setup(Ard) + ΔQ, ΔR = ΔQR + test_rrule( + cr_copy_qr_compact, Ard, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_compact, Ard; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_lq( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_lq_algorithm(A) + @testset "lq_compact" begin + LQ, ΔLQ = ad_lq_compact_setup(A) + ΔL, ΔQ = ΔLQ + test_rrule( + cr_copy_lq_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔL, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_null" begin + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + test_rrule( + cr_copy_lq_null, A, alg ⊢ NoTangent(); + output_tangent = ΔNᴴ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_null, A; + fkwargs = (; positive = true), output_tangent = ΔNᴴ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_full" begin + LQ, ΔLQ = ad_lq_full_setup(A) + test_rrule( + cr_copy_lq_full, A, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_full, A; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rd_compact_setup(Ard) + test_rrule( + cr_copy_lq_compact, Ard, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_compact, Ard; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_eig( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Chainrules AD rules $summary_str" begin + A = make_eig_matrix(T, sz) + m = size(A, 1) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_eig_algorithm(A) + @testset "eig_full" begin + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + ΔD, ΔV = ΔDV + test_rrule( + cr_copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + ) + test_rrule( + cr_copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + ) + test_rrule( + config, eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔD2V, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eig_vals" begin + D, ΔD = ad_eig_vals_setup(A) + test_rrule( + cr_copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + ) + test_rrule( + config, eig_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔD, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + test_rrule( + cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + test_rrule( + cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function test_chainrules_eigh( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH ChainRules AD rules $summary_str" begin + A = make_eigh_matrix(T, sz) + m = size(A, 1) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_eigh_algorithm(A) + # copy_eigh_xxxx includes a projector onto the Hermitian part of the matrix + @testset "eigh_full" begin + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + ΔD, ΔV = ΔDV + test_rrule( + cr_copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + ) + test_rrule( + cr_copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + ) + # eigh_full does not include a projector onto the Hermitian part of the matrix + test_rrule( + config, eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD2V, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eigh_vals" begin + D, ΔD = ad_eigh_vals_setup(A) + test_rrule( + cr_copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + ) + test_rrule( + config, eigh_vals ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eigh_trunc" begin + eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) + eigh_trunc_no_error2(A; kwargs...) = eigh_trunc_no_error(Matrix(Hermitian(A)); kwargs...) + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + test_rrule( + cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eigh_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = truncrank(r; by = real) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), trunc) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + test_rrule( + config, eigh_trunc2, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_trunc_no_error2, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔDVtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + D, ΔD = ad_eigh_vals_setup(A / 2) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + test_rrule( + cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eigh_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = trunctol(; rtol = 1 / 2) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + test_rrule( + config, eigh_trunc2, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_trunc_no_error2, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔDVtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_svd( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_svd_algorithm(A) + @testset "svd_compact" begin + USV, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(A) + test_rrule( + cr_copy_svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUS2Vᴴ, atol = atol, rtol = rtol + ) + test_rrule( + config, svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUS2Vᴴ, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "svd_vals" begin + S, ΔS = ad_svd_vals_setup(A) + test_rrule( + cr_copy_svd_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔS, atol, rtol + ) + test_rrule( + config, svd_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔS, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "svd_trunc" begin + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + test_rrule( + cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔUSVᴴtrunc, + atol = atol, rtol = rtol + ) + U, S, Vᴴ = USVᴴ + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), ΔUSVᴴtrunc) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = truncrank(r) + ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_trunc_no_error, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔUSVᴴtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + S, ΔS = ad_svd_vals_setup(A) + truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + test_rrule( + cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔUSVᴴtrunc, atol = atol, rtol = rtol + ) + U, S, Vᴴ = USVᴴ + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), ΔUSVᴴtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = trunctol(; atol = S[1, 1] / 2) + ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_trunc_no_error, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔUSVᴴtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_polar_algorithm(A) + @testset "left_polar" begin + if m >= n + test_rrule(cr_copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule( + config, left_polar, A, alg ⊢ NoTangent(); + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + @testset "right_polar" begin + if m <= n + test_rrule(cr_copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule( + config, right_polar, A, alg ⊢ NoTangent(); + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + end +end + +function test_chainrules_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + N, ΔN = ad_left_null_setup(A) + Nᴴ, ΔNᴴ = ad_right_null_setup(A) + test_rrule( + config, left_orth, A; + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, left_orth, A; + fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m >= n && + test_rrule( + config, left_orth, A; + fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, left_null, A; + fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + + test_rrule( + config, right_orth, A; + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, right_orth, A; fkwargs = (; alg = :lq), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m <= n && + test_rrule( + config, right_orth, A; fkwargs = (; alg = :polar), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, right_null, A; + fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end +end diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl new file mode 100644 index 00000000..1fcfd66a --- /dev/null +++ b/test/testsuite/mooncake.jl @@ -0,0 +1,480 @@ +using TestExtras +using MatrixAlgebraKit +using Mooncake, Mooncake.TestUtils +using Mooncake: rrule!! +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc + +function mc_copy_eigh_full(A; kwargs...) + A = (A + A') / 2 + return eigh_full(A; kwargs...) +end + +function mc_copy_eigh_full!(A, DV; kwargs...) + A = (A + A') / 2 + return eigh_full!(A, DV; kwargs...) +end + +function mc_copy_eigh_vals(A; kwargs...) + A = (A + A') / 2 + return eigh_vals(A; kwargs...) +end + +function mc_copy_eigh_vals!(A, D; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D; kwargs...) +end + +function mc_copy_eigh_trunc(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc(A, alg; kwargs...) +end + +function mc_copy_eigh_trunc!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc!(A, DV, alg; kwargs...) +end + +function mc_copy_eigh_trunc_no_error(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error(A, alg; kwargs...) +end + +function mc_copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV, alg; kwargs...) +end + +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) + +make_mooncake_tangent(ΔAelem::T) where {T <: Real} = ΔAelem +make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) +make_mooncake_tangent(ΔA::AbstractMatrix{<:Real}) = ΔA +make_mooncake_tangent(ΔA::AbstractVector{<:Real}) = ΔA +make_mooncake_tangent(ΔA::AbstractMatrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔA::AbstractVector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) + +make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...) + +make_mooncake_fdata(x) = make_mooncake_tangent(x) +make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) + +# no `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_copy = make_mooncake_tangent(copy(ΔA)) + A_copy = copy(A) + dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) + copy_pb!!(rdata) + return dA_copy +end + +# `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_copy = make_mooncake_tangent(copy(ΔA)) + A_copy = copy(A) + dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) + copy_pb!!(rdata) + return dA_copy +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_inplace = make_mooncake_tangent(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + end + inplace_pb!!(rdata) + return dA_inplace +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = make_mooncake_tangent(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + end + inplace_pb!!(rdata) + return dA_inplace +end + +""" + test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + +Compare the result of running the *in-place, mutating* function `f!`'s reverse rule +with the result of running its *non-mutating* partner function `f`'s reverse rule. +We must compare directly because many of the mutating functions modify `A` as a +scratch workspace, making testing `f!` against finite differences infeasible. + +The arguments to this function are: + - `f!` the mutating, in-place version of the function (accepts `args` for the function result) + - `f` the non-mutating version of the function (does not accept `args` for the function result) + - `A` the input matrix to factorize + - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) + - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input + - `alg` optional algorithm keyword argument + - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) +""" +function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) + sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + rrule = Mooncake.build_rrule(rvs_interp, sig) + ΔA = randn!(similar(A)) + + dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + + dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] + dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] + @test dA_inplace_ ≈ dA_copy_ + return +end + +function test_mooncake(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake AD $summary_str" begin + test_mooncake_qr(T, sz; kwargs...) + test_mooncake_lq(T, sz; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_mooncake_eig(T, sz; kwargs...) + test_mooncake_eigh(T, sz; kwargs...) + end + test_mooncake_svd(T, sz; kwargs...) + test_mooncake_polar(T, sz; kwargs...) + test_mooncake_orthnull(T, sz; kwargs...) + end +end + +function test_mooncake_qr( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "QR Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "qr_compact" begin + QR, ΔQR = ad_qr_compact_setup(A) + Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(qr_compact!, qr_compact, A, QR, ΔQR) + end + @testset "qr_null" begin + N, ΔN = ad_qr_null_setup(A) + dN = make_mooncake_tangent(copy(ΔN)) + Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol = atol, rtol = rtol) + test_pullbacks_match(qr_null!, qr_null, A, N, ΔN) + end + @testset "qr_full" begin + QR, ΔQR = ad_qr_full_setup(A) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + test_pullbacks_match(qr_full!, qr_full, A, QR, ΔQR) + end + @testset "qr_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + QR, ΔQR = ad_qr_rd_compact_setup(Ard) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) + end + end +end + +function test_mooncake_lq( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "lq_compact" begin + LQ, ΔLQ = ad_lq_compact_setup(A) + Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(lq_compact!, lq_compact, A, LQ, ΔLQ) + end + @testset "lq_null" begin + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(lq_null!, lq_null, A, Nᴴ, ΔNᴴ) + end + @testset "lq_full" begin + LQ, ΔLQ = ad_lq_full_setup(A) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + test_pullbacks_match(lq_full!, lq_full, A, LQ, ΔLQ) + end + @testset "lq_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rd_compact_setup(Ard) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) + end + end +end + +function test_mooncake_eig( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Mooncake AD rules $summary_str" begin + A = make_eig_matrix(T, sz) + m = size(A, 1) + @testset "eig_full" begin + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol = atol, rtol = rtol) + test_pullbacks_match(eig_full!, eig_full, A, DV, ΔD2V) + end + @testset "eig_vals" begin + D, ΔD = ad_eig_vals_setup(A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD) + end + if T <: Number # not a GPU array + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg) + end + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg) + end + end + end +end + +function test_mooncake_eigh( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH Mooncake AD rules $summary_str" begin + A = make_eigh_matrix(T, sz) + m = size(A, 1) + @testset "eigh_full" begin + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(mc_copy_eigh_full!, mc_copy_eigh_full, A, DV, ΔD2V) + end + @testset "eigh_vals" begin + D, ΔD = ad_eigh_vals_setup(A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD) + end + if T <: Number + @testset "eigh_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg) + end + D = eigh_vals(A / 2) + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg) + end + end + end +end + +function test_mooncake_svd( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) + @testset "svd_compact" begin + USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_full" begin + USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_vals" begin + S, ΔS = ad_svd_vals_setup(A) + Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS) + end + if T <: Number # not a GPU array + @testset "svd_trunc" begin + S, ΔS = ad_svd_vals_setup(A) + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg) + end + @testset "trunctol" begin + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg) + end + end + end + end +end + +function test_mooncake_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + @testset "left_polar" begin + if m >= n + WP, ΔWP = ad_left_polar_setup(A) + Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(left_polar!, left_polar, A, WP, ΔWP) + end + end + @testset "right_polar" begin + if m <= n + PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) + Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) + end + end + end +end + +left_orth_qr(X) = left_orth(X; alg = :qr) +left_orth_polar(X) = left_orth(X; alg = :polar) +left_null_qr(X) = left_null(X; alg = :qr) +right_orth_lq(X) = right_orth(X; alg = :lq) +right_orth_polar(X) = right_orth(X; alg = :polar) +right_null_lq(X) = right_null(X; alg = :lq) + +MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) + +function test_mooncake_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + VC, ΔVC = ad_left_orth_setup(A) + CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) + Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(left_orth!, left_orth, A, VC, ΔVC) + Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) + + Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) + if m >= n + Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) + end + + N, ΔN = ad_left_null_setup(A) + dN = make_mooncake_tangent(ΔN) + Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) + test_pullbacks_match(((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) + + Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) + + if m <= n + Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) + end + + Nᴴ, ΔNᴴ = ad_right_null_setup(A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) + test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) + end +end From cf2b8ddc74c86380d7c98eba4c2ed7566ada07e9 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 14:42:17 -0500 Subject: [PATCH 2/3] Fix ChainRules --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index c2de1758..400b2a79 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -95,6 +95,9 @@ for eig in (:eig, :eigh) eig_t! = Symbol(eig, "_trunc!") eig_t_pb = Symbol(eig, "_trunc_pullback") _make_eig_t_pb = Symbol("_make_", eig_t_pb) + eig_t_ne! = Symbol(eig, "_trunc_no_error!") + eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback") + _make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb) eig_v = Symbol(eig, "_vals") eig_v! = Symbol(eig_v, "!") eig_v_pb = Symbol(eig_v, "_pullback") @@ -136,6 +139,24 @@ for eig in (:eig, :eigh) end return $eig_t_pb end + function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm) + Ac = copy_input($eig_f, A) + DV = $(eig_f!)(Ac, DV, alg.alg) + DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) + return DV′, $(_make_eig_t_ne_pb)(A, DV, ind) + end + function $(_make_eig_t_ne_pb)(A, DV, ind) + function $eig_t_ne_pb(ΔDV) + ΔA = zero(A) + ΔD, ΔV = ΔDV + MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful? + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return $eig_t_ne_pb + end function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg) DV = $eig_f(A, alg) function $eig_v_pb(ΔD) From d3a1583eec85fd87647df015edfdf644c158df46 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 16:39:47 -0500 Subject: [PATCH 3/3] Fix LQ? --- test/testsuite/ad_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 4d4108b9..aa565894 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -138,7 +138,7 @@ function ad_lq_full_setup(A) Q1 = view(Q, 1:minmn, 1:n) ΔQ = randn!(similar(A, T, n, n)) ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔQ2 .= (ΔQ2 * Q1') * Q1 ΔL = randn!(similar(A, T, m, n)) return (L, Q), (ΔL, ΔQ) end