|
1 | | -@generated function append_true(::Val{D},::Val{N}) where {D,N} |
| 1 | +@generated function append_true(::Val{D}, ::Val{N}) where {D,N} |
2 | 2 | length(D) == N && return D |
3 | 3 | t = Expr(:tuple) |
4 | | - for d = D |
| 4 | + for d in D |
5 | 5 | push!(t.args, d) |
6 | 6 | end |
7 | 7 | for n = length(D)+1:N |
|
12 | 12 | struct LowDimArray{D,T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} |
13 | 13 | data::A |
14 | 14 | function LowDimArray{D}(data::A) where {D,T,N,A<:AbstractArray{T,N}} |
15 | | - new{append_true(Val{D}(),Val{N}()),T,N,A}(data) |
| 15 | + new{append_true(Val{D}(), Val{N}()),T,N,A}(data) |
16 | 16 | end |
17 | 17 | function LowDimArray{D,T,N,A}(data::A) where {D,T,N,A<:AbstractArray{T,N}} |
18 | | - new{append_true(Val{D}(),Val{N}()),T,N,A}(data) |
| 18 | + new{append_true(Val{D}(), Val{N}()),T,N,A}(data) |
19 | 19 | end |
20 | 20 | end |
21 | | -function LowDimArray{D0}(data::LowDimArray{D1,T,N,A}) where {D0,T,N,D1,A<:AbstractArray{T,N}} |
22 | | - LowDimArray{map(|,D0,D1),T,N,A}(parent(data)) |
| 21 | +function LowDimArray{D0}( |
| 22 | + data::LowDimArray{D1,T,N,A}, |
| 23 | +) where {D0,T,N,D1,A<:AbstractArray{T,N}} |
| 24 | + LowDimArray{map(|, D0, D1),T,N,A}(parent(data)) |
23 | 25 | end |
24 | 26 | Base.@propagate_inbounds Base.getindex( |
25 | 27 | A::LowDimArray, |
|
115 | 117 | end |
116 | 118 | Expr(:block, Expr(:meta, :inline), staticexpr(Cnew)) |
117 | 119 | end |
118 | | -function ArrayInterface.contiguous_axis(::Type{LowDimArrayForBroadcast{D,T,N,A}}) where {D,T,N,A} |
| 120 | +function ArrayInterface.contiguous_axis( |
| 121 | + ::Type{LowDimArrayForBroadcast{D,T,N,A}}, |
| 122 | +) where {D,T,N,A} |
119 | 123 | ArrayInterface.contiguous_axis(A) |
120 | 124 | end |
121 | 125 | @inline function ArrayInterface.stride_rank( |
@@ -180,8 +184,8 @@ function _strides_expr(@nospecialize(s), @nospecialize(x), R::Vector{Int}, D::Ve |
180 | 184 | use_stride_acc = true |
181 | 185 | stride_acc::Int = 1 |
182 | 186 | if is_column_major(R) |
183 | | - # elseif is_row_major(R) |
184 | | - # Nrange = reverse(Nrange) |
| 187 | + # elseif is_row_major(R) |
| 188 | + # Nrange = reverse(Nrange) |
185 | 189 | else # not worth my time optimizing this case at the moment... |
186 | 190 | # will write something generic stride-rank agnostic eventually |
187 | 191 | use_stride_acc = false |
@@ -323,14 +327,8 @@ function add_broadcast!( |
323 | 327 | mA = gensym!(ls, "Aₘₖ") |
324 | 328 | mB = gensym!(ls, "Bₖₙ") |
325 | 329 | gf = GlobalRef(Core, :getfield) |
326 | | - pushprepreamble!( |
327 | | - ls, |
328 | | - Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))), |
329 | | - ) |
330 | | - pushprepreamble!( |
331 | | - ls, |
332 | | - Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))), |
333 | | - ) |
| 330 | + pushprepreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a)))) |
| 331 | + pushprepreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b)))) |
334 | 332 | pushprepreamble!(ls, Expr(:(=), Klen, Expr(:call, gf, Expr(:call, :size, mB), 1, false))) |
335 | 333 | pushpreamble!(ls, Expr(:(=), Krange, Expr(:call, :(:), staticexpr(1), Klen))) |
336 | 334 | k = gensym!(ls, "k") |
@@ -432,7 +430,7 @@ function add_broadcast!( |
432 | 430 | pushprepreamble!(ls, Expr(:(=), bcname2, Expr(:call, forbroadcast, lda))) |
433 | 431 | ArrayReference(bcname2, fulldims) |
434 | 432 | end |
435 | | - |
| 433 | + |
436 | 434 | loadop = add_simple_load!(ls, destname, ref, elementbytes, true)::Operation |
437 | 435 | doaddref!(ls, loadop) |
438 | 436 | end |
@@ -486,10 +484,7 @@ function add_broadcast!( |
486 | 484 | gf = GlobalRef(Core, :getfield) |
487 | 485 | for (i, arg) ∈ enumerate(args) |
488 | 486 | argname = gensym!(ls, "arg") |
489 | | - pushprepreamble!( |
490 | | - ls, |
491 | | - Expr(:(=), argname, Expr(:call, gf, bcargs, i, false)), |
492 | | - ) |
| 487 | + pushprepreamble!(ls, Expr(:(=), argname, Expr(:call, gf, bcargs, i, false))) |
493 | 488 | # dynamic dispatch |
494 | 489 | parent = add_broadcast!( |
495 | 490 | ls, |
|
542 | 537 | bc::BC, |
543 | 538 | ::Val{Mod}, |
544 | 539 | ::Val{UNROLL}, |
545 | | - ::Val{dontbc} |
| 540 | + ::Val{dontbc}, |
546 | 541 | ) where {T<:NativeTypes,N,BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc} |
547 | 542 | # 2 + 1 |
548 | 543 | # we have an N dimensional loop. |
|
580 | 575 | bc::BC, |
581 | 576 | ::Val{Mod}, |
582 | 577 | ::Val{UNROLL}, |
583 | | - ::Val{dontbc} |
584 | | -) where {T<:NativeTypes,N,A<:AbstractArray{T,N},BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc} |
| 578 | + ::Val{dontbc}, |
| 579 | +) where { |
| 580 | + T<:NativeTypes, |
| 581 | + N, |
| 582 | + A<:AbstractArray{T,N}, |
| 583 | + BC<:Union{Broadcasted,Product}, |
| 584 | + Mod, |
| 585 | + UNROLL, |
| 586 | + dontbc, |
| 587 | +} |
585 | 588 | # we have an N dimensional loop. |
586 | 589 | # need to construct the LoopSet |
587 | 590 | ls = LoopSet(Mod) |
@@ -626,14 +629,14 @@ end |
626 | 629 | bc::Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(identity),Tuple{T2}}, |
627 | 630 | ::Val{Mod}, |
628 | 631 | ::Val{UNROLL}, |
629 | | - ::Val{dontbc} |
| 632 | + ::Val{dontbc}, |
630 | 633 | ) where {T<:NativeTypes,N,T2<:Number,Mod,UNROLL,dontbc} |
631 | 634 | inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL |
632 | 635 | quote |
633 | 636 | $(Expr(:meta, :inline)) |
634 | 637 | arg = T(first(bc.args)) |
635 | 638 | @turbo inline = $inline unroll = ($u₁, $u₂) thread = $threads vectorize = $v for i ∈ |
636 | | - eachindex( |
| 639 | + eachindex( |
637 | 640 | dest, |
638 | 641 | ) |
639 | 642 | dest[i] = arg |
|
646 | 649 | bc::Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(identity),Tuple{T2}}, |
647 | 650 | ::Val{Mod}, |
648 | 651 | ::Val{UNROLL}, |
649 | | - ::Val{dontbc} |
| 652 | + ::Val{dontbc}, |
650 | 653 | ) where {T<:NativeTypes,N,A<:AbstractArray{T,N},T2<:Number,Mod,UNROLL,dontbc} |
651 | 654 | inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL |
652 | 655 | quote |
|
680 | 683 | end |
681 | 684 |
|
682 | 685 | # vmaterialize!(dest, bc, ::Val, ::Val, ::StaticInt, ::StaticInt, ::StaticInt) = |
683 | | - # Base.Broadcast.materialize!(dest, bc) |
| 686 | +# Base.Broadcast.materialize!(dest, bc) |
0 commit comments