Skip to content

Commit 08e4ccf

Browse files
committed
feat: lower svd
1 parent c93a3aa commit 08e4ccf

File tree

3 files changed

+84
-7
lines changed

3 files changed

+84
-7
lines changed

src/Ops.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3312,6 +3312,39 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
33123312
return (res, ipiv, perm, info)
33133313
end
33143314

3315+
@noinline function svd(
3316+
x::TracedRArray{T,N};
3317+
full::Bool=false,
3318+
location=mlir_stacktrace("svd", @__FILE__, @__LINE__),
3319+
) where {T,N}
3320+
@assert N >= 2
3321+
3322+
batch_sizes = size(x)[1:(end - 2)]
3323+
m, n = size(x)[(end - 1):end]
3324+
r = min(m, n)
3325+
3326+
U_size = (batch_sizes..., m, full ? m : r)
3327+
S_size = (batch_sizes..., r)
3328+
Vt_size = (batch_sizes..., full ? n : r, n)
3329+
info_size = batch_sizes
3330+
3331+
svd_op = enzymexla.linalg_svd(
3332+
x.mlir_data;
3333+
U=mlir_type(TracedRArray{T,N}, U_size),
3334+
S=mlir_type(TracedRArray{T,N - 1}, S_size),
3335+
Vt=mlir_type(TracedRArray{T,N}, Vt_size),
3336+
info=mlir_type(TracedRArray{Int32,N - 2}, info_size),
3337+
full=full,
3338+
location,
3339+
)
3340+
3341+
U = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 1), U_size)
3342+
S = TracedRArray{T,N - 1}((), MLIR.IR.result(svd_op, 2), S_size)
3343+
Vt = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 3), Vt_size)
3344+
3345+
return U, S, Vt
3346+
end
3347+
33153348
@noinline function reduce_window(
33163349
f::F,
33173350
inputs::Vector{TracedRArray{T,N}},

src/stdlibs/LinearAlgebra.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using LinearAlgebra: LinearAlgebra, BLAS
1515
using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot
1616
using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal
1717
using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular
18-
using LinearAlgebra: diag, diagm, ldiv!, inv, rmul!
18+
using LinearAlgebra: diag, diagm, ldiv!, rmul!, normalize, norm
1919
using Libdl: Libdl
2020
using GPUArraysCore: @allowscalar
2121

@@ -330,7 +330,7 @@ end
330330
# LinearAlgebra defines norm with some conditionals which cannot be traced directly
331331
function LinearAlgebra.norm(x::TracedRArray{T,N}, p::Real=2) where {T,N}
332332
isinf(p) && return maximum(abs, x)
333-
return mapreduce(Base.Fix2(^, p), +, x)^(1 / p)
333+
return mapreduce(Base.Fix2(^, p), +, x)^(T(1 / p))
334334
end
335335

336336
function LinearAlgebra._diagm(shape, kv::Pair{<:Integer,<:AnyTracedRVector}...)

src/stdlibs/factorization/SVD.jl

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct GeneralizedSVD{T,Tr,M<:AbstractArray{T},C<:AbstractArray{T}} <: Factorization{T}
1+
struct GeneralizedSVD{T,Tr,M<:AbstractArray,C<:AbstractArray} <: Factorization{T}
22
U::M
33
S::C
44
Vt::M
@@ -9,6 +9,12 @@ struct GeneralizedSVD{T,Tr,M<:AbstractArray{T},C<:AbstractArray{T}} <: Factoriza
99
end
1010
end
1111

12+
function GeneralizedSVD(
13+
U::AbstractArray{T}, S::AbstractArray{Tr}, Vt::AbstractArray{T}
14+
) where {T,Tr}
15+
return GeneralizedSVD{T,Tr,typeof(U),typeof(S)}(U, S, Vt)
16+
end
17+
1218
function overloaded_svd(A::AbstractArray; kwargs...)
1319
return overloaded_svd(Reactant.promote_to(TracedRArray, A); kwargs...)
1420
end
@@ -17,17 +23,55 @@ function overloaded_svd(
1723
A::AnyTracedRArray{T,N}; full::Bool=false, algorithm=nothing
1824
) where {T,N}
1925
# TODO: don't ignore the algorithm kwarg
20-
return error("TODO: Not implemented yet")
26+
U, S, Vt = @opcall svd(A; full)
27+
return GeneralizedSVD(U, S, Vt)
2128
end
2229

2330
function overloaded_svd(
2431
A::AnyTracedRVector{T}; full::Bool=false, algorithm=nothing
2532
) where {T}
2633
# TODO: don't ignore the algorithm kwarg
27-
m = length(A)
28-
normA = LinearAlgebra.norm(A)
34+
normA = Reactant.call_with_reactant(LinearAlgebra.norm, A)
35+
U, S, Vt = if full
36+
ReactantCore.traced_if(
37+
iszero(normA), zeronorm_vector_svd_full, vector_svd_full, (A, normA)
38+
)
39+
else
40+
ReactantCore.traced_if(iszero(normA), zeronorm_vector_svd, vector_svd, (A, normA))
41+
end
42+
return GeneralizedSVD(U, S, Vt)
43+
end
44+
45+
function zeronorm_vector_svd(A::AbstractVector{T}, normA) where {T}
46+
return zeronorm_vector_svd(A, false, normA)
47+
end
48+
function zeronorm_vector_svd_full(A::AbstractVector{T}, normA) where {T}
49+
return zeronorm_vector_svd(A, true, normA)
50+
end
51+
52+
function zeronorm_vector_svd(A::AbstractVector{T}, full::Bool, normA) where {T}
53+
U = Reactant.promote_to(
54+
TracedRArray,
55+
Matrix{Reactant.unwrapped_eltype(T)}(
56+
LinearAlgebra.I, length(A), full ? length(A) : 1
57+
),
58+
)
59+
return U, fill(normA, 1), ones(T, 1, 1)
60+
end
2961

30-
return error("TODO: Not implemented yet")
62+
vector_svd(A::AbstractVector{T}, normA) where {T} = vector_svd(A, false, normA)
63+
function vector_svd_full(A::AbstractVector{T}, normA) where {T}
64+
return vector_svd(A, true, normA)
65+
end
66+
67+
function vector_svd(A::AbstractVector{T}, full::Bool, normA) where {T}
68+
if !full
69+
U = materialize_traced_array(reshape(normalize(A), length(A), 1))
70+
return U, fill(normA, 1), ones(T, 1, 1)
71+
end
72+
return @opcall svd(
73+
materialize_traced_array(reshape(normalize(A), length(A), 1)); full
74+
)
3175
end
3276

3377
# TODO: compute svdvals without computing the full svd. In principle we should

0 commit comments

Comments
 (0)