@@ -468,100 +468,29 @@ function Base.mapreduce(
468468 dims= :,
469469 init= nothing ,
470470) where {T,N}
471- A = materialize_traced_array (A)
471+ inp = broadcast (f, materialize_traced_array (A) )
472472
473- if dims isa Int
474- dims = [dims]
475- end
476-
477- op_in_T = Core. Compiler. return_type (f, Tuple{T})
473+ dims isa Number && (dims = (dims,))
478474
479- if init === nothing
480- if op === min
481- init = typemax (op_in_T)
482- elseif op === max
483- init = typemin (op_in_T)
484- else
485- init = Base. reduce_empty (Base. BottomRF (op), op_in_T)
486- end
487-
488- if typeof (init) != op_in_T
489- op_in_T = typeof (init)
490- A = typeof (init).(A)
491- end
475+ if init != = nothing && typeof (init) != unwrapped_eltype (inp)
476+ inp = typeof (init).(inp)
492477 end
493478
494- init = [TracedUtils. broadcast_to_size (init, ()). mlir_data]
495-
496- inp = [broadcast (f, A). mlir_data]
479+ rdims = dims == (:) ? collect (Int64, 1 : N) : collect (Int64, dims)
497480
498- rdims = Int64[]
481+ reduction_result = Ops . reduce (inp, nothing , rdims, op)
499482
500- if dims == (:)
501- for i in 0 : (N - 1 )
502- push! (rdims, i)
503- end
483+ reduction_result = if dims != (:)
484+ Ops. reshape (reduction_result, Int64[i ∈ rdims ? 1 : size (A, i) for i in 1 : N])
504485 else
505- for i in dims
506- push! (rdims, i - 1 )
507- end
508- end
509-
510- in_tys = [
511- MLIR. IR. TensorType (Int64[], eltype (MLIR. IR. type (inp[1 ]))),
512- MLIR. IR. TensorType (Int64[], eltype (MLIR. IR. type (init[1 ]))),
513- ]
514-
515- fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location (), MLIR. IR. Location ()])
516-
517- args = (
518- TracedRNumber {Reactant.unwrapped_eltype(op_in_T)} ((), MLIR. IR. argument (fnbody, 1 )),
519- TracedRNumber {Reactant.unwrapped_eltype(op_in_T)} ((), MLIR. IR. argument (fnbody, 2 )),
520- )
521-
522- resty = MLIR. IR. block! (fnbody) do
523- tmp = TracedUtils. broadcast_to_size (op (args... ), ())
524- Ops. return_ (tmp)
525- return eltype (MLIR. IR. type (tmp. mlir_data))
486+ TracedRNumber {unwrapped_eltype(reduction_result)} ((), reduction_result. mlir_data)
526487 end
527488
528- toonedims = Int[]
529- outdims = Int[]
530- for i in 1 : N
531- tmp = if in (i - 1 , rdims)
532- 1
533- else
534- sz = size (A, i)
535- push! (outdims, sz)
536- sz
537- end
538- push! (toonedims, tmp)
539- end
540-
541- TT = MLIR. IR. Type[MLIR. IR. TensorType (outdims, resty)]
542-
543- body = MLIR. IR. Region ()
544- push! (body, fnbody)
545- red = MLIR. Dialects. stablehlo. reduce (
546- inp, init; result_0= TT, dimensions= MLIR. IR. DenseArrayAttribute (rdims), body
547- )
548-
549- red = MLIR. IR. result (red, 1 )
550- redT = eltype (MLIR. IR. julia_type (MLIR. IR. type (red)))
551-
552- if dims != (:)
553- red = Ops. reshape (TracedRArray (red), toonedims... )
554- else
555- if length (outdims) == 0
556- red = TracedRNumber {redT} ((), red)
557- else
558- red = TracedRArray {redT,length(outdims)} ((), red, (outdims... ,))
559- end
560- end
561- return red
489+ init === nothing && return reduction_result
490+ return broadcast (op, reduction_result, init)
562491end
563492
564- function Base. mapreducedim ! (
493+ function Base. _mapreducedim ! (
565494 @nospecialize (f),
566495 @nospecialize (op),
567496 @nospecialize (R:: AnyTracedRArray ),
@@ -573,9 +502,9 @@ function Base.mapreducedim!(
573502 @assert sR == 1
574503 return i
575504 end
505+ isempty (A) && return R
576506 tmp = mapreduce (f, op, A; dims= filter (! isnothing, dims))
577- # set_mlir_data!(R, get_mlir_data(tmp))
578- R .= op .(R, tmp) # match native Julia's behavior
507+ R .= op .(R, tmp)
579508 return R
580509end
581510
0 commit comments