Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]

[sources]
CUDA = {url="https://github.com/JuliaGPU/CUDA.jl", rev="master"}
21 changes: 20 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKitCUDAExt
using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
Expand Down Expand Up @@ -183,4 +183,23 @@ function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
return A, B
end

MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4)
MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4)
function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...)
As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...))
return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4)
end

function LinearAlgebra.sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
#=m = size(A, 1)
n = size(B, 2)
I_n = fill!(similar(A, n), one(eltype(A)))
I_m = fill!(similar(B, m), one(eltype(B)))
L = kron(diagm(I_n), A) + kron(adjoint(B), diagm(I_m))
x_vec = L \ -vec(C)
X = CuMatrix(reshape(x_vec, m, n))=#
hX = sylvester(collect(A), collect(B), collect(C))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very awful but I wasn't able to find a correct way to do it in five minutes so there you go

return CuArray(hX)
end

end
21 changes: 21 additions & 0 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ for eig in (:eig, :eigh)
eig_t! = Symbol(eig, "_trunc!")
eig_t_pb = Symbol(eig, "_trunc_pullback")
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
eig_t_ne! = Symbol(eig, "_trunc_no_error!")
eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback")
_make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb)
eig_v = Symbol(eig, "_vals")
eig_v! = Symbol(eig_v, "!")
eig_v_pb = Symbol(eig_v, "_pullback")
Expand Down Expand Up @@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
end
return $eig_t_pb
end
function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm)
Ac = copy_input($eig_f, A)
DV = $(eig_f!)(Ac, DV, alg.alg)
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
return DV′, $(_make_eig_t_ne_pb)(A, DV, ind)
end
function $(_make_eig_t_ne_pb)(A, DV, ind)
function $eig_t_ne_pb(ΔDV)
ΔA = zero(A)
ΔD, ΔV = ΔDV
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return $eig_t_ne_pb
end
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
DV = $eig_f(A, alg)
function $eig_v_pb(ΔD)
Expand Down
13 changes: 12 additions & 1 deletion ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview, copy_input
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
Expand All @@ -26,6 +26,17 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
return CoDual(Ac, dAc), copy_input_pb
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any}
function Mooncake.rrule!!(::CoDual{typeof(initialize_output)}, f_df::CoDual, A_dA::CoDual, alg_dalg::CoDual)
output = initialize_output(Mooncake.primal(f_df), Mooncake.primal(A_dA), Mooncake.primal(alg_dalg))
doutput = Mooncake.zero_tangent(output)
function initialize_output_pb(::NoRData)
return NoRData(), NoRData(), NoRData(), NoRData()
end
return CoDual(output, doutput), initialize_output_pb
end


# two-argument in-place factorizations like LQ, QR, EIG
for (f!, f, pb, adj) in (
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),
Expand Down
9 changes: 6 additions & 3 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ function eig_pullback!(
mul!(view(VᴴΔV, :, indV), V', ΔV)

mask = abs.(transpose(D) .- D) .< degeneracy_atol
Δgauge = norm(view(VᴴΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
if isa(ΔA, Array)
# not GPU friendly...
Δgauge = norm(view(VᴴΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end

VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))

Expand Down
46 changes: 26 additions & 20 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,24 @@ function lq_pullback!(
ΔA1 = view(ΔA, 1:p, :)
ΔA2 = view(ΔA, (p + 1):m, :)

if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
Δgauge = max(Δgauge, norm(ΔL22, Inf))
if isa(ΔA, Array) # not GPU friendly
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
Δgauge = max(Δgauge, norm(ΔL22, Inf))
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end

ΔQ̃ = zero!(similar(Q, (p, n)))
Expand All @@ -69,9 +71,11 @@ function lq_pullback!(
# how the full Q2 will change, but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge dependent quantity.
ΔQ2Q1ᴴ = ΔQ2 * Q1'
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
if isa(ΔA, Array) # not GPU friendly
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
end
end
Expand All @@ -95,8 +99,10 @@ function lq_pullback!(
Md = diagview(M)
Md .= real.(Md)
end
ldiv!(LowerTriangular(L11)', M)
ldiv!(LowerTriangular(L11)', ΔQ̃)
# not GPU friendly...
L11arr = typeof(L)(L11)
ldiv!(LowerTriangular(L11arr)', M)
ldiv!(LowerTriangular(L11arr)', ΔQ̃)
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
ΔA1 .+= ΔQ̃
return ΔA
Expand Down
46 changes: 26 additions & 20 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,24 @@ function qr_pullback!(
ΔA1 = view(ΔA, :, 1:p)
ΔA2 = view(ΔA, :, (p + 1):n)

if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
Δgauge = max(Δgauge, norm(ΔR22, Inf))
if isa(ΔA, Array) # not GPU friendly
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
Δgauge = max(Δgauge, norm(ΔR22, Inf))
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end

ΔQ̃ = zero!(similar(Q, (m, p)))
Expand All @@ -69,9 +71,11 @@ function qr_pullback!(
# how the full Q2 will change, but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge dependent quantity.
Q1dΔQ2 = Q1' * ΔQ2
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
if isa(ΔA, Array) # not GPU friendly
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
end
end
Expand All @@ -95,8 +99,10 @@ function qr_pullback!(
Md = diagview(M)
Md .= real.(Md)
end
rdiv!(M, UpperTriangular(R11)')
rdiv!(ΔQ̃, UpperTriangular(R11)')
# not GPU-friendly...
R11arr = typeof(R)(R11)
rdiv!(M, UpperTriangular(R11arr)')
rdiv!(ΔQ̃, UpperTriangular(R11arr)')
ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
ΔA1 .+= ΔQ̃
return ΔA
Expand Down
14 changes: 8 additions & 6 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
"""
function svd_pullback!(
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon();
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
rank_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])),
degeneracy_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)
# Extract the SVD components
Expand All @@ -33,7 +33,7 @@ function svd_pullback!(
minmn = min(m, n)
S = diagview(Smat)
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
r = searchsortedlast(S, rank_atol; rev = true) # rank
r = findlast(s -> s ≥ rank_atol, S) # rank
Ur = view(U, :, 1:r)
Vᴴr = view(Vᴴ, 1:r, :)
Sr = view(S, 1:r)
Expand Down Expand Up @@ -71,9 +71,11 @@ function svd_pullback!(

# check whether cotangents arise from gauge-invariance objective function
mask = abs.(Sr' .- Sr) .< degeneracy_atol
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
if isa(ΔA, Array) # norm check not GPU friendly
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end

UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+
(aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol)
Expand Down
62 changes: 0 additions & 62 deletions test/ad_utils.jl

This file was deleted.

Loading