1- struct GeneralizedSVD{T,Tr,M<: AbstractArray{T} ,C<: AbstractArray{T} } <: Factorization{T}
1+ struct GeneralizedSVD{T,Tr,M<: AbstractArray ,C<: AbstractArray } <: Factorization{T}
22 U:: M
33 S:: C
44 Vt:: M
@@ -9,6 +9,12 @@ struct GeneralizedSVD{T,Tr,M<:AbstractArray{T},C<:AbstractArray{T}} <: Factoriza
99 end
1010end
1111
12+ function GeneralizedSVD (
13+ U:: AbstractArray{T} , S:: AbstractArray{Tr} , Vt:: AbstractArray{T}
14+ ) where {T,Tr}
15+ return GeneralizedSVD {T,Tr,typeof(U),typeof(S)} (U, S, Vt)
16+ end
17+
1218function overloaded_svd (A:: AbstractArray ; kwargs... )
1319 return overloaded_svd (Reactant. promote_to (TracedRArray, A); kwargs... )
1420end
@@ -17,17 +23,55 @@ function overloaded_svd(
1723 A:: AnyTracedRArray{T,N} ; full:: Bool = false , algorithm= nothing
1824) where {T,N}
1925 # TODO : don't ignore the algorithm kwarg
20- return error (" TODO: Not implemented yet" )
26+ U, S, Vt = @opcall svd (A; full)
27+ return GeneralizedSVD (U, S, Vt)
2128end
2229
2330function overloaded_svd (
2431 A:: AnyTracedRVector{T} ; full:: Bool = false , algorithm= nothing
2532) where {T}
2633 # TODO : don't ignore the algorithm kwarg
27- m = length (A)
28- normA = LinearAlgebra. norm (A)
34+ normA = Reactant. call_with_reactant (LinearAlgebra. norm, A)
35+ U, S, Vt = if full
36+ ReactantCore. traced_if (
37+ iszero (normA), zeronorm_vector_svd_full, vector_svd_full, (A, normA)
38+ )
39+ else
40+ ReactantCore. traced_if (iszero (normA), zeronorm_vector_svd, vector_svd, (A, normA))
41+ end
42+ return GeneralizedSVD (U, S, Vt)
43+ end
44+
45+ function zeronorm_vector_svd (A:: AbstractVector{T} , normA) where {T}
46+ return zeronorm_vector_svd (A, false , normA)
47+ end
48+ function zeronorm_vector_svd_full (A:: AbstractVector{T} , normA) where {T}
49+ return zeronorm_vector_svd (A, true , normA)
50+ end
51+
52+ function zeronorm_vector_svd (A:: AbstractVector{T} , full:: Bool , normA) where {T}
53+ U = Reactant. promote_to (
54+ TracedRArray,
55+ Matrix {Reactant.unwrapped_eltype(T)} (
56+ LinearAlgebra. I, length (A), full ? length (A) : 1
57+ ),
58+ )
59+ return U, fill (normA, 1 ), ones (T, 1 , 1 )
60+ end
2961
30- return error (" TODO: Not implemented yet" )
62+ vector_svd (A:: AbstractVector{T} , normA) where {T} = vector_svd (A, false , normA)
63+ function vector_svd_full (A:: AbstractVector{T} , normA) where {T}
64+ return vector_svd (A, true , normA)
65+ end
66+
67+ function vector_svd (A:: AbstractVector{T} , full:: Bool , normA) where {T}
68+ if ! full
69+ U = materialize_traced_array (reshape (normalize (A), length (A), 1 ))
70+ return U, fill (normA, 1 ), ones (T, 1 , 1 )
71+ end
72+ return @opcall svd (
73+ materialize_traced_array (reshape (normalize (A), length (A), 1 )); full
74+ )
3175end
3276
3377# TODO : compute svdvals without computing the full svd. In principle we should
0 commit comments