@@ -1,23 +1,19 @@
 module ReactantCUDAExt
 
-using CUDA
 using Reactant: Reactant, TracedRArray, AnyConcretePJRTArray, MLIR, TracedRNumber
 using Reactant.Compiler: raising
-using ReactantCore: @trace
+using Reactant.Ops: @opcall
+
+using Adapt: Adapt, adapt
+using CUDA: CUDA, CuDim, DenseCuArray, unsafe_cached_load
+
 using GPUCompiler: GPUCompiler
 using KernelAbstractions: KernelAbstractions
-import KernelAbstractions as KA
 using LLVM: LLVM
-using Libdl
-
-using Reactant.Ops: @opcall
 
-const ReactantKernelAbstractionsExt = Base.get_extension(
-    Reactant, :ReactantKernelAbstractionsExt
-)
-const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend
+using PrecompileTools: @setup_workload, @compile_workload
 
-using Adapt
+const KA = KernelAbstractions
 
 Reactant.is_extension_loaded(::Val{:CUDA}) = true
 
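In the import block above, the removed `import KernelAbstractions as KA` and the added `const KA = KernelAbstractions` bind the same module object to the same name. A stand-alone sketch of the equivalence, using the LinearAlgebra stdlib in place of KernelAbstractions:

import LinearAlgebra                # binds the module name itself
import LinearAlgebra as LA_import   # alias via `import ... as` (Julia >= 1.6)
const LA_const = LinearAlgebra      # alias via a const binding, as in this diff

# Both aliases refer to the identical module object.
@assert LA_import === LinearAlgebra && LA_const === LinearAlgebra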
@@ -64,9 +60,7 @@ function Base.getindex(RN::CuTracedRNumber{T,A}) where {T,A}
     return @inbounds unsafe_load(RN.ptr, 1, Val(align))
 end
 
-function Base.convert(::Type{T}, RN::CuTracedRNumber) where {T<:Number}
-    return Base.convert(T, Base.getindex(RN))
-end
+Base.convert(::Type{T}, RN::CuTracedRNumber) where {T<:Number} = convert(T, getindex(RN))
 
 for jlop in (
     :(Base.min),
@@ -89,17 +83,15 @@ for jlop in (
     end
 end
 
-@inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = Base.ifelse(cond, a, b[])
-@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = Base.ifelse(cond, a[], b)
+@inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = ifelse(cond, a, b[])
+@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = ifelse(cond, a[], b)
 @inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b::CuTracedRNumber) =
-    Base.ifelse(cond, a[], b[])
-@inline Base.ifelse(cond::CuTracedRNumber, a, b) = Base.ifelse(cond[], a, b)
-@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) =
-    Base.ifelse(cond[], a[], b)
-@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) =
-    Base.ifelse(cond[], a, b[])
+    ifelse(cond, a[], b[])
+@inline Base.ifelse(cond::CuTracedRNumber, a, b) = ifelse(cond[], a, b)
+@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) = ifelse(cond[], a[], b)
+@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) = ifelse(cond[], a, b[])
 @inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b::CuTracedRNumber) =
-    Base.ifelse(cond[], a[], b[])
+    ifelse(cond[], a[], b[])
 
 Base.@constprop :aggressive @inline Base.:^(
     a::CuTracedRNumber{T,A}, b::Integer
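The rewritten `ifelse` overloads above drop the redundant `Base.` qualifier on the right-hand side; inside the module the bare name `ifelse` still resolves to `Base.ifelse`, so behavior is unchanged while each definition now fits on one line. A minimal stand-alone sketch of the same unwrap-then-dispatch pattern, using a hypothetical `Boxed` wrapper in place of `CuTracedRNumber`:

# Hypothetical stand-in for CuTracedRNumber: the wrapped value is read
# with getindex, mirroring the a[] / b[] calls in the hunk above.
struct Boxed{T}
    val::T
end
Base.getindex(b::Boxed) = b.val

# Unwrap each Boxed argument, then fall through to the Bool/plain-value
# method (the bare `ifelse` resolves to Base.ifelse).
Base.ifelse(cond::Bool, a, b::Boxed) = ifelse(cond, a, b[])
Base.ifelse(cond::Bool, a::Boxed, b) = ifelse(cond, a[], b)
Base.ifelse(cond::Boxed, a, b) = ifelse(cond[], a, b)

@assert ifelse(Boxed(true), 1, 2) == 1
@assert ifelse(false, Boxed(1), 2) == 2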
@@ -140,7 +132,7 @@
             ),
             Core.LLVMPtr{UInt8,1},
             Tuple{Float64},
-            Base.convert(Float64, x),
+            convert(Float64, x),
         ),
     ),
 )
@@ -164,7 +156,7 @@
             ),
             Core.LLVMPtr{UInt8,1},
             Tuple{Float32},
-            Base.convert(Float32, x),
+            convert(Float32, x),
         ),
     ),
 )
@@ -181,7 +173,7 @@ Base.@nospecializeinfer function Base.promote_rule(
     @nospecialize(a::Type{<:CuTracedRNumber{T}}),
     @nospecialize(b::Type{<:CuTracedRNumber{T2}})
 ) where {T,T2}
-    return Base.promote_rule(T, T2)
+    return promote_rule(T, T2)
 end
 Base.@nospecializeinfer function Base.promote_rule(
     ::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber})
@@ -199,7 +191,7 @@ Base.@nospecializeinfer function Base.promote_rule(
     if T == T2
         return T
     else
-        return Base.promote_rule(T, T2)
+        return promote_rule(T, T2)
     end
 end
 Base.@nospecializeinfer function Base.promote_rule(
@@ -208,7 +200,7 @@ Base.@nospecializeinfer function Base.promote_rule(
     if T == T2
         return T
     else
-        return Base.promote_rule(T, T2)
+        return promote_rule(T, T2)
     end
 end
 
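All three `promote_rule` hunks make the same mechanical change: the recursive call loses its `Base.` prefix, which is safe because the unqualified `promote_rule` inside the module already resolves to `Base.promote_rule`. For reference, a minimal stand-alone example of how `promote_rule` definitions feed `promote_type` (hypothetical types, not part of this PR):

struct Meters
    v::Float64
end
struct Feet
    v::Float64
end

# Defining one direction suffices: promote_type queries both argument
# orders and keeps the non-bottom result.
Base.promote_rule(::Type{Meters}, ::Type{Feet}) = Meters

@assert promote_type(Meters, Feet) == Meters
@assert promote_type(Feet, Meters) == Meters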
@@ -506,9 +498,7 @@ function threads_to_workgroupsize(threads, ndrange)
     end
 end
 
-function ReactantKernelAbstractionsExt.ka_with_reactant(
-    ndrange, workgroupsize, obj, args...
-)
+function Reactant.ka_with_reactant(ndrange, workgroupsize, obj, args...)
     backend = KA.backend(obj)
 
     ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(
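`ka_with_reactant` is no longer looked up through the `ReactantKernelAbstractionsExt` extension module (previously fetched with `Base.get_extension`); the method now attaches to a function owned by `Reactant` itself. A sketch of that ownership pattern with assumed names (the actual stub lives somewhere in Reactant):

# In the parent package: declare an empty generic function so that
# extensions have a shared name to attach methods to.
module Parent
function ka_with_reactant end
end

# In an extension: add the method directly; no extension lookup needed.
function Parent.ka_with_reactant(ndrange, workgroupsize, obj, args...)
    return (ndrange, workgroupsize, obj, args)
end

@assert Parent.ka_with_reactant(16, 4, :kernel) == (16, 4, :kernel, ())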
@@ -588,7 +578,7 @@ function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRNumber{T}) wher
     return res
 end
 
-import Reactant.TracedRNumberOverrides.TracedStepRangeLen
+import Reactant.TracedStepRangeLen
 
 function Adapt.adapt_storage(::ReactantKernelAdaptor, r::TracedStepRangeLen)
     return TracedStepRangeLen(
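The shortened import works because `TracedStepRangeLen` is now reachable directly under the top-level `Reactant` module; the `adapt_storage` method itself is unchanged. Its shape follows the usual Adapt.jl field-by-field rebuild, sketched here with a hypothetical adaptor over Base's `StepRangeLen`:

using Adapt: Adapt, adapt

struct MyAdaptor end  # hypothetical adaptor standing in for ReactantKernelAdaptor

# Rebuild the range by adapting each component, as the diff does for
# TracedStepRangeLen.
function Adapt.adapt_storage(to::MyAdaptor, r::StepRangeLen)
    return StepRangeLen(adapt(to, r.ref), adapt(to, r.step), r.len, r.offset)
end

r = StepRangeLen(1.0, 0.5, 5)
@assert Adapt.adapt_storage(MyAdaptor(), r) == r  # plain numbers adapt to themselves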
@@ -1481,7 +1471,7 @@ end
 # In Julia v1.11.3 precompiling this module caches bad code:
 # <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
 @static if !Sys.isapple()
-    Reactant.PrecompileTools.@setup_workload begin
+    @setup_workload begin
         Reactant.initialize_dialect()
 
         if Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT"
@@ -1492,7 +1482,7 @@ end
             error("Unsupported runtime: $(Reactant.XLA.REACTANT_XLA_RUNTIME)")
         end
 
-        Reactant.PrecompileTools.@compile_workload begin
+        @compile_workload begin
             @static if Reactant.precompilation_supported() && VERSION != v"1.11.3"
                 function square_kernel!(x)
                     i = CUDA.threadIdx().x
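Both workload macros are now provided by the `using PrecompileTools: @setup_workload, @compile_workload` line added at the top of the file, instead of being spelled through `Reactant.PrecompileTools`. For context, the standard PrecompileTools skeleton looks like this (a minimal, generic example, independent of the CUDA kernel workload above):

using PrecompileTools: @setup_workload, @compile_workload

@setup_workload begin
    # Setup code runs during precompilation but is not itself cached.
    data = rand(Float32, 16)
    @compile_workload begin
        # Calls executed here are compiled and cached into the package image.
        sum(data)
        sort(data)
    end
end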