@@ -94,7 +94,6 @@ for (jlop, hloop) in (
9494 (:(Base.:* ), :multiply ),
9595 (:(Base.:/ ), :divide ),
9696 (:(Base.:^ ), :power ),
97- (:(Base. mod), :remainder ),
9897 (:(Base. rem), :remainder ),
9998)
10099 @eval function $ (jlop)(
@@ -109,13 +108,30 @@ function Base.rem(
109108) where {T}
110109 return Ops. remainder (lhs, TracedUtils. promote_to (TracedRNumber{T}, rhs))
111110end
112-
113111function Base. rem (
114112 @nospecialize (lhs:: Number ), @nospecialize (rhs:: TracedRNumber{T} )
115113) where {T}
116114 return Ops. remainder (TracedUtils. promote_to (TracedRNumber{T}, lhs), rhs)
117115end
118116
117+ # Based on https://github.com/JuliaLang/julia/blob/39255d47db7657950ff1c82137ecec5a70bae622/base/float.jl#L608-L617
118+ function Base. mod (
119+ @nospecialize (x:: Reactant.TracedRNumber{T} ), @nospecialize (y:: Reactant.TracedRNumber{T} )
120+ ) where {T}
121+ r = rem (x, y)
122+ return ifelse (r == 0 , copysign (r, y), ifelse ((r > 0 ) ⊻ (y > 0 ), r + y, r))
123+ end
124+ function Base. mod (
125+ @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: Number )
126+ ) where {T}
127+ return mod (lhs, TracedUtils. promote_to (TracedRNumber{T}, rhs))
128+ end
129+ function Base. mod (
130+ @nospecialize (lhs:: Number ), @nospecialize (rhs:: TracedRNumber{T} )
131+ ) where {T}
132+ return mod (TracedUtils. promote_to (TracedRNumber{T}, lhs), rhs)
133+ end
134+
119135function Base. div (@nospecialize (lhs:: TracedRNumber{T} ), rhs) where {T<: Integer }
120136 return Ops. divide (lhs, TracedUtils. promote_to (TracedRNumber{T}, rhs))
121137end
@@ -224,6 +240,12 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
224240 TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
225241 )
226242 end
243+ function Base. xor (x:: TracedRNumber{<:$(T1)} , y:: TracedRNumber{<:$(T2)} )
244+ return Ops. xor (
245+ TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
246+ TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
247+ )
248+ end
227249 Base.:! (x:: TracedRNumber{<:$(T1)} ) = Ops. not (x)
228250 end
229251end
@@ -391,4 +413,20 @@ function Base.typed_hvncat(
391413 return Base. typed_hvncat (T, dims, row_first, xs... )
392414end
393415
416+ for (Ti, Tf) in ((Int16, Float16), (Int32, Float32), (Int64, Float64))
417+ @eval begin
418+ Base. signbit (x:: TracedRNumber{$(Ti)} ) = x < 0
419+ Base. signbit (x:: TracedRNumber{$(Tf)} ) = signbit (Ops. bitcast_convert ($ (Ti), x))
420+ end
421+ end
422+ Base. signbit (:: TracedRNumber{<:Unsigned} ) = ConcreteRNumber (false )
423+
424+ Base. copysign (x:: TracedRNumber , y:: TracedRNumber ) = ifelse (signbit (y), - 1 , 1 ) * abs (x)
425+ function Base. copysign (x:: TracedRNumber{T} , y:: S ) where {T,S<: Number }
426+ return copysign (x, TracedUtils. promote_to (TracedRNumber{S}, y))
394427end
428+ function Base. copysign (x:: S , y:: TracedRNumber{T} ) where {S<: Number ,T}
429+ return copysign (TracedUtils. promote_to (TracedRNumber{S}, x), y)
430+ end
431+
432+ end # module TracedRNumberOverrides
0 commit comments