Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SerializedArrays"
uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.2"
version = "0.1.3"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
module SerializedArraysLinearAlgebraExt

using LinearAlgebra: LinearAlgebra, mul!
using SerializedArrays: AbstractSerializedMatrix
using SerializedArrays: AbstractSerializedMatrix, memory

function mul_serialized!(
a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number
)
mul!(a_dest, memory(a1), memory(a2), α, β)
return a_dest
end

function LinearAlgebra.mul!(
a_dest::AbstractMatrix,
Expand All @@ -10,8 +17,27 @@ function LinearAlgebra.mul!(
α::Number,
β::Number,
)
mul!(a_dest, copy(a1), copy(a2), α, β)
return a_dest
return mul_serialized!(a_dest, a1, a2, α, β)
end

function LinearAlgebra.mul!(
a_dest::AbstractMatrix,
a1::AbstractMatrix,
a2::AbstractSerializedMatrix,
α::Number,
β::Number,
)
return mul_serialized!(a_dest, a1, a2, α, β)
end

function LinearAlgebra.mul!(
a_dest::AbstractMatrix,
a1::AbstractSerializedMatrix,
a2::AbstractMatrix,
α::Number,
β::Number,
)
return mul_serialized!(a_dest, a1, a2, α, β)
end

for f in [:eigen, :qr, :svd]
Expand Down
51 changes: 30 additions & 21 deletions src/SerializedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
using Serialization: deserialize, serialize

memory(a) = a

#
# AbstractSerializedArray
#
Expand All @@ -13,6 +15,9 @@
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2}
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1}

memory(a::AbstractSerializedArray) = copy(a)
disk(a::AbstractSerializedArray) = a

function _copyto_write!(dst, src)
writeblock!(dst, src, axes(src)...)
return dst
Expand All @@ -30,11 +35,11 @@
end
# Fix ambiguity error.
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractSerializedArray)
return copyto!(dst, copy(src))
return copyto!(dst, memory(src))
end
# Fix ambiguity error.
function Base.copyto!(dst::AbstractDiskArray, src::AbstractSerializedArray)
return copyto!(dst, copy(src))
return copyto!(dst, memory(src))

Check warning on line 42 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L42

Added line #L42 was not covered by tests
end
# Fix ambiguity error.
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractDiskArray)
Expand All @@ -45,26 +50,28 @@
return _copyto_read!(dst, src)
end

equals_serialized(a1, a2) = memory(a1) == memory(a2)

function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractSerializedArray)
return copy(a1) == copy(a2)
return equals_serialized(a1, a2)
end
function Base.:(==)(a1::AbstractArray, a2::AbstractSerializedArray)
return a1 == copy(a2)
return equals_serialized(a1, a2)
end
function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray)
return copy(a1) == a2
return equals_serialized(a1, a2)
end

# # These cause too many ambiguity errors, try bringing them back.
# function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
# return arrayt(a)
# end
# function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
# return convert(arrayt, copy(a))
# return convert(arrayt, memory(a))
# end
# # Fixes ambiguity error.
# function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
# return convert(arrayt, copy(a))
# return convert(arrayt, memory(a))
# end

#
Expand All @@ -79,6 +86,8 @@
Base.axes(a::SerializedArray) = getfield(a, :axes)
arraytype(a::SerializedArray{<:Any,<:Any,A}) where {A} = A

disk(a::AbstractArray) = SerializedArray(a)

function SerializedArray(file::String, a::AbstractArray)
serialize(file, a)
ax = axes(a)
Expand Down Expand Up @@ -114,10 +123,10 @@
a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
) where {N}
if i == axes(a)
aout .= copy(a)
aout .= memory(a)
return a
end
aout .= @view copy(a)[i...]
aout .= @view memory(a)[i...]
return a
end
function DiskArrays.writeblock!(
Expand All @@ -127,7 +136,7 @@
serialize(file(a), ain)
return a
end
a′ = copy(a)
a′ = memory(a)
a′[i...] = ain
serialize(file(a), a′)
return a
Expand Down Expand Up @@ -171,7 +180,7 @@
end

function materialize(a::PermutedSerializedArray)
return PermutedDimsArray(copy(parent(a)), perm(a))
return PermutedDimsArray(memory(parent(a)), perm(a))
end
function Base.copy(a::PermutedSerializedArray)
return copy(materialize(a))
Expand Down Expand Up @@ -241,7 +250,7 @@
# friendly on GPU. Consider special cases of strded arrays
# and handle with stride manipulations.
function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray})
a′ = reshape(copy(parent(a)), axes(a))
a′ = reshape(memory(parent(a)), axes(a))
return a′ isa Base.ReshapedArray ? copy(a′) : a′
end

Expand All @@ -254,10 +263,10 @@
a::ReshapedSerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
) where {N}
if i == axes(a)
aout .= copy(a)
aout .= memory(a)

Check warning on line 266 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L266

Added line #L266 was not covered by tests
return a
end
aout .= @view copy(a)[i...]
aout .= @view memory(a)[i...]

Check warning on line 269 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L269

Added line #L269 was not covered by tests
return nothing
end
function DiskArrays.writeblock!(
Expand All @@ -267,7 +276,7 @@
serialize(file(a), ain)
return a
end
a′ = copy(a)
a′ = memory(a)

Check warning on line 279 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L279

Added line #L279 was not covered by tests
a′[i...] = ain
serialize(file(a), a′)
return nothing
Expand Down Expand Up @@ -307,17 +316,17 @@
DiskArrays.haschunks(a::SubSerializedArray) = Unchunked()
function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...)
if i == axes(a)
aout .= copy(a)
aout .= memory(a)

Check warning on line 319 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L319

Added line #L319 was not covered by tests
end
aout[i...] = copy(view(a, i...))
aout[i...] = memory(view(a, i...))
return nothing
end
function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...)
if i == axes(a)
serialize(file(a), ain)
return a
end
a_parent = copy(parent(a))
a_parent = memory(parent(a))
pinds = parentindices(view(a.sub_parent, i...))
a_parent[pinds...] = ain
serialize(file(a), a_parent)
Expand Down Expand Up @@ -349,7 +358,7 @@
end

function materialize(a::TransposeSerializedArray)
return transpose(copy(parent(a)))
return transpose(memory(parent(a)))
end
function Base.copy(a::TransposeSerializedArray)
return copy(materialize(a))
Expand Down Expand Up @@ -392,7 +401,7 @@
end

function materialize(a::AdjointSerializedArray)
return adjoint(copy(parent(a)))
return adjoint(memory(parent(a)))
end
function Base.copy(a::AdjointSerializedArray)
return copy(materialize(a))
Expand Down Expand Up @@ -445,7 +454,7 @@
Base.broadcastable(a::BroadcastSerializedArray) = a.broadcasted
function Base.copy(a::BroadcastSerializedArray)
# Broadcast over the materialized arrays.
return copy(Base.Broadcast.broadcasted(a.broadcasted.f, copy.(a.broadcasted.args)...))
return copy(Base.Broadcast.broadcasted(a.broadcasted.f, memory.(a.broadcasted.args)...))
end

function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N}
Expand Down
10 changes: 9 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ using SerializedArrays:
ReshapedSerializedArray,
SerializedArray,
SubSerializedArray,
TransposeSerializedArray
TransposeSerializedArray,
disk,
memory
using StableRNGs: StableRNG
using Test: @test, @testset
using TestExtras: @constinferred
Expand All @@ -21,6 +23,12 @@ arrayts = (Array, JLArray)
a = SerializedArray(x)
@test @constinferred(copy(a)) == x
@test typeof(copy(a)) == typeof(x)
@test memory(a) == x
@test memory(a) isa arrayt{elt,2}
@test memory(x) === x
@test disk(a) === a
@test disk(x) == a
@test disk(x) isa SerializedArray{elt,2,<:arrayt{elt,2}}

x = arrayt(zeros(elt, 4, 4))
a = SerializedArray(x)
Expand Down
8 changes: 8 additions & 0 deletions test/test_linearalgebraext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ arrayts = (Array, JLArray)
@test c == x * y
@test c isa arrayt{elt,2}

c = @constinferred(x * b)
@test c == x * y
@test c isa arrayt{elt,2}

c = @constinferred(a * y)
@test c == x * y
@test c isa arrayt{elt,2}

a = permutedims(SerializedArray(x), (2, 1))
b = permutedims(SerializedArray(y), (2, 1))
c = @constinferred(a * b)
Expand Down
Loading