Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SerializedArrays"
uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.3"
version = "0.2.0"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
[compat]
Documenter = "1"
Literate = "2"
SerializedArrays = "0.1"
SerializedArrays = "0.2"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"

[compat]
SerializedArrays = "0.1"
SerializedArrays = "0.2"
109 changes: 57 additions & 52 deletions src/SerializedArrays.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
module SerializedArrays

export SerializedArray, disk, memory

using Base.PermutedDimsArrays: genperm
using ConstructionBase: constructorof
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
using Serialization: deserialize, serialize

memory(a) = a
adapt_serialized(to, x) = adapt_structure_serialized(to, x)
adapt_serialized(to) = Base.Fix1(adapt_structure_serialized, to)
adapt_structure_serialized(to, x) = adapt_storage_serialized(to, x)
adapt_storage_serialized(to, x) = x

struct DeepMemoryAdaptor end
deepmemory(x) = adapt_serialized(DeepMemoryAdaptor(), x)

struct MemoryAdaptor end
memory(x) = adapt_serialized(MemoryAdaptor(), x)

#
# AbstractSerializedArray
Expand All @@ -15,9 +26,12 @@
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2}
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1}

memory(a::AbstractSerializedArray) = copy(a)
disk(a::AbstractSerializedArray) = a

function Base.copy(a::AbstractSerializedArray)
return copy(memory(a))
end

function _copyto_write!(dst, src)
writeblock!(dst, src, axes(src)...)
return dst
Expand Down Expand Up @@ -62,18 +76,6 @@
return equals_serialized(a1, a2)
end

# # These cause too many ambiguity errors, try bringing them back.
# function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
# return arrayt(a)
# end
# function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
# return convert(arrayt, memory(a))
# end
# # Fixes ambiguity error.
# function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
# return convert(arrayt, memory(a))
# end

#
# SerializedArray
#
Expand Down Expand Up @@ -105,11 +107,19 @@
return constructorof(arraytype(a)){elt}(undef, dims...)
end

function materialize(a::SerializedArray)
function _memory(a::SerializedArray)
return deserialize(file(a))::arraytype(a)
end

function adapt_storage_serialized(::DeepMemoryAdaptor, a::SerializedArray)
return _memory(a)
end
function adapt_storage_serialized(::MemoryAdaptor, a::SerializedArray)
return _memory(a)
end

function Base.copy(a::SerializedArray)
return materialize(a)
return memory(a)
end

Base.size(a::SerializedArray) = length.(axes(a))
Expand All @@ -123,7 +133,7 @@
a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
) where {N}
if i == axes(a)
aout .= memory(a)
aout .= deepmemory(a)
return a
end
aout .= @view memory(a)[i...]
Expand Down Expand Up @@ -179,11 +189,13 @@
return similar(parent(a), elt, dims)
end

function materialize(a::PermutedSerializedArray)
return PermutedDimsArray(memory(parent(a)), perm(a))
function adapt_structure_serialized(to, a::PermutedSerializedArray)
return PermutedDimsArray(adapt_serialized(to, parent(a)), perm(a))
end
function Base.copy(a::PermutedSerializedArray)
return copy(materialize(a))

# Special case to eagerly instantiate permutations.
function adapt_structure_serialized(to::MemoryAdaptor, a::PermutedSerializedArray)
return copy(deepmemory(a))
end

haschunks(a::PermutedSerializedArray) = Unchunked()
Expand Down Expand Up @@ -238,19 +250,14 @@
return similar(parent(a), elt, dims)
end

function materialize(a::ReshapedSerializedArray)
return reshape(materialize(parent(a)), axes(a))
function adapt_structure_serialized(to, a::ReshapedSerializedArray)
return reshape(adapt_serialized(to, parent(a)), axes(a))
end
function Base.copy(a::ReshapedSerializedArray)
a′ = materialize(a)
return a′ isa Base.ReshapedArray ? copy(a′) : a′
end

# Special case for handling nested wrappers that aren't
# friendly on GPU. Consider special cases of strded arrays
# and handle with stride manipulations.
function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray})
a′ = reshape(memory(parent(a)), axes(a))
# `memory` instantiates `PermutedSerializedArray`, which is
# friendlier for GPU. Consider special cases of strded arrays
# and handle with stride manipulations.
a′ = memory(a)
return a′ isa Base.ReshapedArray ? copy(a′) : a′
end

Expand Down Expand Up @@ -306,17 +313,14 @@
Base.parent(a::SubSerializedArray) = parent(a.sub_parent)
Base.parentindices(a::SubSerializedArray) = parentindices(a.sub_parent)

function materialize(a::SubSerializedArray)
return view(copy(parent(a)), parentindices(a)...)
end
function Base.copy(a::SubSerializedArray)
return copy(materialize(a))
function adapt_structure_serialized(to, a::SubSerializedArray)
return view(adapt_serialized(to, parent(a)), parentindices(a)...)
end

DiskArrays.haschunks(a::SubSerializedArray) = Unchunked()
function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...)
if i == axes(a)
aout .= memory(a)
aout .= deepmemory(a)

Check warning on line 323 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L323

Added line #L323 was not covered by tests
end
aout[i...] = memory(view(a, i...))
return nothing
Expand All @@ -326,7 +330,7 @@
serialize(file(a), ain)
return a
end
a_parent = memory(parent(a))
a_parent = deepmemory(parent(a))
pinds = parentindices(view(a.sub_parent, i...))
a_parent[pinds...] = ain
serialize(file(a), a_parent)
Expand Down Expand Up @@ -357,11 +361,8 @@
return similar(parent(a), elt, dims)
end

function materialize(a::TransposeSerializedArray)
return transpose(memory(parent(a)))
end
function Base.copy(a::TransposeSerializedArray)
return copy(materialize(a))
function adapt_structure_serialized(to, a::TransposeSerializedArray)
return transpose(adapt_serialized(to, parent(a)))
end

haschunks(a::TransposeSerializedArray) = Unchunked()
Expand Down Expand Up @@ -400,11 +401,8 @@
return similar(parent(a), elt, dims)
end

function materialize(a::AdjointSerializedArray)
return adjoint(memory(parent(a)))
end
function Base.copy(a::AdjointSerializedArray)
return copy(materialize(a))
function adapt_structure_serialized(to, a::AdjointSerializedArray)
return adjoint(adapt_serialized(to, parent(a)))
end

haschunks(a::AdjointSerializedArray) = Unchunked()
Expand Down Expand Up @@ -452,9 +450,16 @@
end
Base.size(a::BroadcastSerializedArray) = size(a.broadcasted)
Base.broadcastable(a::BroadcastSerializedArray) = a.broadcasted
function Base.copy(a::BroadcastSerializedArray)
# Broadcast over the materialized arrays.
return copy(Base.Broadcast.broadcasted(a.broadcasted.f, memory.(a.broadcasted.args)...))

function adapt_structure_serialized(to, a::BroadcastSerializedArray)
return Base.Broadcast.broadcasted(
a.broadcasted.f, map(adapt_serialized(to), a.broadcasted.args)...
)
end

# Special case to eagerly instantiate broadcasts.
function adapt_storage_serialized(::MemoryAdaptor, a::BroadcastSerializedArray)
return copy(a)

Check warning on line 462 in src/SerializedArrays.jl

View check run for this annotation

Codecov / codecov/patch

src/SerializedArrays.jl#L461-L462

Added lines #L461 - L462 were not covered by tests
end

function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N}
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ GPUArraysCore = "0.2"
JLArrays = "0.2"
LinearAlgebra = "1.10"
SafeTestsets = "0.1"
SerializedArrays = "0.1"
SerializedArrays = "0.2"
StableRNGs = "1"
Suppressor = "0.2"
Test = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ arrayts = (Array, JLArray)
rng = StableRNG(123)
x = arrayt(randn(rng, elt, 4, 4))
y = @view x[2:3, 2:3]
a = SerializedArray(a)
a = SerializedArray(x)
b = @view a[2:3, 2:3]
@test b isa SubSerializedArray{elt,2}
c = 2b
Expand Down
Loading