122122 end
123123end
124124
125- @noinline function constant (
126- x:: AbstractArray{T,N} ; location= mlir_stacktrace (" constant" , @__FILE__ , @__LINE__ )
127- ) where {T,N}
128- return constant (collect (x); location)
129- end
130-
131- @noinline function constant (x:: Reactant.AbstractConcreteArray ; kwargs... )
132- return constant (Base. convert (Array, x); kwargs... )
133- end
134-
135125@noinline function constant (
136126 x:: T ; location= mlir_stacktrace (" constant" , @__FILE__ , @__LINE__ )
137127) where {T<: Number }
140130 return TracedRNumber {T} ((), res. mlir_data)
141131end
142132
143- @noinline function constant (x:: Reactant.AbstractConcreteNumber{T} ; kwargs... ) where {T}
144- return constant (Base. convert (T, x); kwargs... )
145- end
146-
147133function fill (
148134 v, dims:: Base.DimOrInd... ; location= mlir_stacktrace (" fill" , @__FILE__ , @__LINE__ )
149135)
391377end
392378
393379# shape ops
394- function reshape (x:: TracedRArray , dims:: Integer ... ; kwargs... )
380+ function reshape (x:: TracedRArray , dims... ; kwargs... )
395381 return reshape (x, collect (dims); kwargs... )
396382end
397383
@@ -2394,7 +2380,7 @@ end
23942380 x::TracedRArray{T},
23952381 init_values::TracedRNumber{T},
23962382 dimensions::Vector{Int},
2397- fn::Function;
2383+ fn::Function,
23982384 location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
23992385 )
24002386
@@ -2426,43 +2412,25 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
24262412 - **CPU version & Julia's `reduce`**:
24272413 - Reduce along dimension 1 → `[(15) (21); (18) (24)]`
24282414 - Reduce along dimension 3 → `[(33 + 2) (45 + 2)]` → `[35 47]`
2429-
2415+
24302416 - **GPU version**:
24312417 - Reduce along dimension 1 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
24322418 - Reduce along dimension 3 → `[37 49]`
24332419"""
24342420@noinline function reduce (
24352421 x:: TracedRArray{T} ,
2436- init_values:: Union{ TracedRNumber{T},Nothing } ,
2422+ init_values:: TracedRNumber{T} ,
24372423 dimensions:: Vector{Int} ,
2438- fn:: Function ;
2424+ fn:: Function ,
24392425 location= mlir_stacktrace (" reduce" , @__FILE__ , @__LINE__ ),
24402426) where {T}
2441- elT = T
2442- if init_values === nothing
2443- if fn === min || fn === Base. FastMath. min_fast
2444- init = typemax (elT)
2445- elseif fn === max || fn === Base. FastMath. max_fast
2446- init = typemin (elT)
2447- else
2448- init = Base. reduce_empty (Base. BottomRF (fn), elT)
2449- end
2450-
2451- initT = unwrapped_eltype (typeof (init))
2452- if initT != elT # Bool, etc. reductions
2453- elT = promote_type (initT, elT)
2454- x = elT .(x)
2455- end
2456- init_values = Reactant. TracedUtils. promote_to (TracedRNumber{elT}, init)
2457- end
2458-
24592427 reduced_shape = Tuple (deleteat! (collect (size (x)), dimensions))
24602428
2461- result_type = mlir_type (TracedRArray{elT ,length (reduced_shape)}, reduced_shape)
2429+ result_type = mlir_type (TracedRArray{T ,length (reduced_shape)}, reduced_shape)
24622430
24632431 sample_inputs = [
2464- Reactant. TracedUtils. promote_to (TracedRNumber{elT }, 0 ),
2465- Reactant. TracedUtils. promote_to (TracedRNumber{elT }, 0 ),
2432+ Reactant. TracedUtils. promote_to (TracedRNumber{T }, 0 ),
2433+ Reactant. TracedUtils. promote_to (TracedRNumber{T }, 0 ),
24662434 ]
24672435
24682436 func =
@@ -2476,8 +2444,14 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
24762444 return_dialect= :stablehlo ,
24772445 ). f
24782446 @assert MLIR. IR. nregions (func) == 1
2479- ftype = MLIR. IR. Type (MLIR. IR. attr (func, " function_type" ))
2480- @assert MLIR. IR. result (ftype) == MLIR. IR. TensorType ((), MLIR. IR. Type (elT)) " $fn return type is not tensor<i1>"
2447+ fn_name = String (
2448+ MLIR. IR. attr (func, String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ()))
2449+ )
2450+ ftype_attr = MLIR. IR. attr (func, " function_type" )
2451+ ftype = MLIR. IR. Type (ftype_attr)
2452+ @assert MLIR. IR. result (ftype) == MLIR. IR. TensorType ((), MLIR. IR. Type (T)) error (
2453+ " $fn return type is not tensor<i1>"
2454+ )
24812455 fn = MLIR. IR. Region ()
24822456 MLIR. API. mlirRegionTakeBody (fn, MLIR. IR. region (func, 1 ))
24832457 MLIR. IR. rmfromparent! (func)
@@ -2495,7 +2469,7 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
24952469 ),
24962470 )
24972471
2498- return TracedRArray {elT ,length(reduced_shape)} ((), res, reduced_shape)
2472+ return TracedRArray {T ,length(reduced_shape)} ((), res, reduced_shape)
24992473end
25002474
25012475end # module Ops
0 commit comments