@@ -5,7 +5,7 @@ using ..Reactant: Reactant, Ops
55using .. Reactant:
66 TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector
77using .. Reactant: call_with_reactant
8- using ReactantCore: ReactantCore, materialize_traced_array
8+ using ReactantCore: ReactantCore, materialize_traced_array, @trace
99using Reactant_jll: Reactant_jll
1010
1111using .. TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
@@ -15,7 +15,7 @@ using LinearAlgebra: LinearAlgebra, BLAS
1515using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot
1616using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal
1717using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular
18- using LinearAlgebra: diag, diagm, ldiv!
18+ using LinearAlgebra: diag, diagm, ldiv!, inv, rmul!
1919using Libdl: Libdl
2020using GPUArraysCore: @allowscalar
2121
@@ -694,4 +694,44 @@ function LinearAlgebra.ishermitian(A::AnyTracedRMatrix)
694694 return all (A .== adjoint (A))
695695end
696696
697+ function LinearAlgebra. isbanded (A:: AnyTracedRMatrix , kl:: Integer , ku:: Integer )
698+ return LinearAlgebra. istriu (A, kl) && LinearAlgebra. istril (A, ku)
699+ end
700+
701+ @static if isdefined (LinearAlgebra, :__normalize! )
702+ function LinearAlgebra. __normalize! (a:: AnyTracedRArray , nrm)
703+ # The largest positive floating point number whose inverse is less than infinity
704+ δ = inv (prevfloat (typemax (nrm)))
705+ @trace if nrm ≥ δ # Safe to multiply with inverse
706+ invnrm = inv (nrm)
707+ rmul! (a, invnrm)
708+ else # scale elements to avoid overflow
709+ εδ = eps (one (nrm)) / δ
710+ rmul! (a, εδ)
711+ rmul! (a, inv (nrm * εδ))
712+ end
713+ return a
714+ end
715+ end
716+
717+ function LinearAlgebra. rmul! (A:: AnyTracedRArray , b:: Number )
718+ @. A *= b
719+ return A
720+ end
721+
722+ function LinearAlgebra. lmul! (b:: Number , A:: AnyTracedRArray )
723+ @. A = b * A
724+ return A
725+ end
726+
727+ function LinearAlgebra. rdiv! (A:: AnyTracedRArray , b:: Number )
728+ @. A /= b
729+ return A
730+ end
731+
732+ function LinearAlgebra. ldiv! (b:: Number , A:: AnyTracedRArray )
733+ @. A = b \ A
734+ return A
735+ end
736+
697737end
0 commit comments