Skip to content

Commit 9cc4e6a

Browse files
committed
feat: lower svdvals
1 parent 08e4ccf commit 9cc4e6a

File tree

1 file changed

+5
-3
lines changed
  • src/stdlibs/factorization

1 file changed

+5
-3
lines changed

src/stdlibs/factorization/SVD.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,12 @@ function vector_svd(A::AbstractVector{T}, full::Bool, normA) where {T}
6969
U = materialize_traced_array(reshape(normalize(A), length(A), 1))
7070
return U, fill(normA, 1), ones(T, 1, 1)
7171
end
72-
return @opcall svd(
73-
materialize_traced_array(reshape(normalize(A), length(A), 1)); full
74-
)
72+
return @opcall svd(materialize_traced_array(reshape(normalize(A), length(A), 1)); full)
7573
end
7674

7775
# TODO: compute svdvals without computing the full svd. In principle we should
7876
# simple dce the U and Vt inside the compiler itself and simply compute Σ
77+
LinearAlgebra.svdvals(x::AnyTracedRArray{T,N}) where {T,N} = overloaded_svd(x).S
78+
LinearAlgebra.svdvals!(x::AnyTracedRArray{T,N}) where {T,N} = overloaded_svd(x).S
79+
LinearAlgebra.svdvals(x::AnyTracedRVector{T}) where {T} = overloaded_svd(x).S
80+
LinearAlgebra.svdvals!(x::AnyTracedRVector{T}) where {T} = overloaded_svd(x).S

0 commit comments

Comments
 (0)