Skip to content

Commit c93a3aa

Browse files
committed
feat: more operation coverage
1 parent 77c5563 commit c93a3aa

File tree

5 files changed

+96
-3
lines changed

5 files changed

+96
-3
lines changed

src/TracedRArray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ Base.elsize(::Type{TracedRArray{T,N}}) where {T,N} = sizeof(T)
2929
# we use it
3030
Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T, x)
3131

32+
# Base.first is very common usecase for getting first element to get the type
33+
# inside LinearAlgebra.jl
34+
Base.first(x::TracedRArray{T,N}) where {T,N} = @allowscalar(x[1])
35+
3236
# Base.complex
3337
Base.complex(x::TracedRArray{<:Real}) = complex.(x)
3438
Base.complex(x::TracedRArray{<:Complex}) = x

src/TracedRNumber.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,25 @@ Base.copy(x::TracedRNumber{T}) where {T} = TracedRNumber{T}((), x.mlir_data)
2626
function Base.eps(::Type{TracedRNumber{T}}) where {T}
2727
return Reactant.promote_to(TracedRNumber{T}, eps(T))
2828
end
29+
Base.eps(x::TracedRNumber{T}) where {T} = eps(typeof(x))
2930

3031
function Base.typemin(::Type{TracedRNumber{T}}) where {T}
3132
return Reactant.promote_to(TracedRNumber{T}, typemin(T))
3233
end
34+
Base.typemin(x::TracedRNumber{T}) where {T} = typemin(typeof(x))
35+
3336
function Base.typemax(::Type{TracedRNumber{T}}) where {T}
3437
return Reactant.promote_to(TracedRNumber{T}, typemax(T))
3538
end
39+
Base.typemax(x::TracedRNumber{T}) where {T} = typemax(typeof(x))
40+
41+
function Base.nextfloat(x::TracedRNumber{T}) where {T<:AbstractFloat}
42+
return @opcall next_after(x, typemax(x))
43+
end
44+
45+
function Base.prevfloat(x::TracedRNumber{T}) where {T<:AbstractFloat}
46+
return @opcall next_after(x, typemin(x))
47+
end
3648

3749
function Base.rtoldefault(T::Type{<:TracedRNumber})
3850
return T(Base.rtoldefault(unwrapped_eltype(T)))

src/stdlibs/LinearAlgebra.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using ..Reactant: Reactant, Ops
55
using ..Reactant:
66
TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector
77
using ..Reactant: call_with_reactant
8-
using ReactantCore: ReactantCore, materialize_traced_array
8+
using ReactantCore: ReactantCore, materialize_traced_array, @trace
99
using Reactant_jll: Reactant_jll
1010

1111
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
@@ -15,7 +15,7 @@ using LinearAlgebra: LinearAlgebra, BLAS
1515
using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot
1616
using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal
1717
using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular
18-
using LinearAlgebra: diag, diagm, ldiv!
18+
using LinearAlgebra: diag, diagm, ldiv!, inv, rmul!
1919
using Libdl: Libdl
2020
using GPUArraysCore: @allowscalar
2121

@@ -694,4 +694,44 @@ function LinearAlgebra.ishermitian(A::AnyTracedRMatrix)
694694
return all(A .== adjoint(A))
695695
end
696696

697+
function LinearAlgebra.isbanded(A::AnyTracedRMatrix, kl::Integer, ku::Integer)
698+
return LinearAlgebra.istriu(A, kl) && LinearAlgebra.istril(A, ku)
699+
end
700+
701+
@static if isdefined(LinearAlgebra, :__normalize!)
702+
function LinearAlgebra.__normalize!(a::AnyTracedRArray, nrm)
703+
# The largest positive floating point number whose inverse is less than infinity
704+
δ = inv(prevfloat(typemax(nrm)))
705+
@trace if nrm δ # Safe to multiply with inverse
706+
invnrm = inv(nrm)
707+
rmul!(a, invnrm)
708+
else # scale elements to avoid overflow
709+
εδ = eps(one(nrm)) / δ
710+
rmul!(a, εδ)
711+
rmul!(a, inv(nrm * εδ))
712+
end
713+
return a
714+
end
715+
end
716+
717+
function LinearAlgebra.rmul!(A::AnyTracedRArray, b::Number)
718+
@. A *= b
719+
return A
720+
end
721+
722+
function LinearAlgebra.lmul!(b::Number, A::AnyTracedRArray)
723+
@. A = b * A
724+
return A
725+
end
726+
727+
function LinearAlgebra.rdiv!(A::AnyTracedRArray, b::Number)
728+
@. A /= b
729+
return A
730+
end
731+
732+
function LinearAlgebra.ldiv!(b::Number, A::AnyTracedRArray)
733+
@. A = b \ A
734+
return A
735+
end
736+
697737
end

src/stdlibs/factorization/Cholesky.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ end
88
Base.size(c::GeneralizedCholesky) = size(c.factors)
99
Base.ndims(c::GeneralizedCholesky) = ndims(c.factors)
1010

11+
function overloaded_cholesky(A::AbstractArray, ::NoPivot; check::Bool=false)
12+
return overloaded_cholesky(Reactant.promote_to(TracedRArray, A), NoPivot(); check)
13+
end
14+
1115
function overloaded_cholesky(
12-
A::AbstractArray{T,N}, ::NoPivot; check::Bool=false
16+
A::AnyTracedRArray{T,N}, ::NoPivot; check::Bool=false
1317
) where {T,N}
1418
# TODO: dont ignore check
1519
# move the batching dims to the front

src/stdlibs/factorization/SVD.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,34 @@
1+
struct GeneralizedSVD{T,Tr,M<:AbstractArray{T},C<:AbstractArray{T}} <: Factorization{T}
2+
U::M
3+
S::C
4+
Vt::M
15

6+
function GeneralizedSVD{T,Tr,M,C}(U::M, S::C, Vt::M) where {T,Tr,M,C}
7+
@assert ndims(S) == ndims(U) - 1
8+
return new{T,Tr,M,C}(U, S, Vt)
9+
end
10+
end
11+
12+
function overloaded_svd(A::AbstractArray; kwargs...)
13+
return overloaded_svd(Reactant.promote_to(TracedRArray, A); kwargs...)
14+
end
15+
16+
function overloaded_svd(
17+
A::AnyTracedRArray{T,N}; full::Bool=false, algorithm=nothing
18+
) where {T,N}
19+
# TODO: don't ignore the algorithm kwarg
20+
return error("TODO: Not implemented yet")
21+
end
22+
23+
function overloaded_svd(
24+
A::AnyTracedRVector{T}; full::Bool=false, algorithm=nothing
25+
) where {T}
26+
# TODO: don't ignore the algorithm kwarg
27+
m = length(A)
28+
normA = LinearAlgebra.norm(A)
29+
30+
return error("TODO: Not implemented yet")
31+
end
32+
33+
# TODO: compute svdvals without computing the full svd. In principle we should
34+
# simple dce the U and Vt inside the compiler itself and simply compute Σ

0 commit comments

Comments
 (0)