Skip to content
167 changes: 166 additions & 1 deletion pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
MemoryObject,
MemoryMap,
Buffer,
PooledBuffer,

_Program,
Kernel,
Expand Down Expand Up @@ -197,7 +198,7 @@
enqueue_migrate_mem_objects, unload_platform_compiler)

if get_cl_header_version() >= (2, 0):
from pyopencl._cl import SVM, SVMAllocation, SVMPointer
from pyopencl._cl import SVM, SVMAllocation, SVMPointer, PooledSVM

if _cl.have_gl():
from pyopencl._cl import ( # noqa: F401
Expand Down Expand Up @@ -2407,4 +2408,168 @@ def fsvm_empty_like(ctx, ary, alignment=None):
_KERNEL_ARG_CLASSES = (*_KERNEL_ARG_CLASSES, SVM)


# {{{ pickling support

import threading
from contextlib import contextmanager


_QUEUE_FOR_PICKLING_TLS = threading.local()


@contextmanager
def queue_for_pickling(queue, alloc=None):
r"""A context manager that, for the current thread, sets the command queue
to be used for pickling and unpickling :class:`Array`\ s and :class:`Buffer`\ s
to *queue*."""
try:
existing_pickle_queue = _QUEUE_FOR_PICKLING_TLS.queue
except AttributeError:
existing_pickle_queue = None

if existing_pickle_queue is not None:
raise RuntimeError("queue_for_pickling should not be called "
"inside the context of its own invocation.")

_QUEUE_FOR_PICKLING_TLS.queue = queue
_QUEUE_FOR_PICKLING_TLS.alloc = alloc
try:
yield None
finally:
_QUEUE_FOR_PICKLING_TLS.queue = None
_QUEUE_FOR_PICKLING_TLS.alloc = None


def _get_queue_for_pickling(obj):
try:
queue = _QUEUE_FOR_PICKLING_TLS.queue
alloc = _QUEUE_FOR_PICKLING_TLS.alloc
except AttributeError:
queue = None

if queue is None:
raise RuntimeError(f"{type(obj).__name__} instances can only be pickled while "
"queue_for_pickling is active.")

return queue, alloc


def _getstate_buffer(self):
import pyopencl as cl
queue, _alloc = _get_queue_for_pickling(self)

state = {}
state["size"] = self.size
state["flags"] = self.flags

a = bytearray(self.size)
cl.enqueue_copy(queue, a, self)

state["_pickle_data"] = a

return state


def _setstate_buffer(self, state):
import pyopencl as cl
queue, _alloc = _get_queue_for_pickling(self)

size = state["size"]
flags = state["flags"]

a = state["_pickle_data"]
Buffer.__init__(self, queue.context, flags | cl.mem_flags.COPY_HOST_PTR, size, a)


Buffer.__getstate__ = _getstate_buffer
Buffer.__setstate__ = _setstate_buffer


def _getstate_pooledbuffer(self):
import pyopencl as cl
queue, _alloc = _get_queue_for_pickling(self)

state = {}
state["size"] = self.size
state["flags"] = self.flags

a = bytearray(self.size)
cl.enqueue_copy(queue, a, self)
state["_pickle_data"] = a

return state


def _setstate_pooledbuffer(self, state):
_queue, _alloc = _get_queue_for_pickling(self)

_size = state["size"]
_flags = state["flags"]

_a = state["_pickle_data"]
# FIXME: Unclear what to do here - PooledBuffer does not have __init__


PooledBuffer.__getstate__ = _getstate_pooledbuffer
PooledBuffer.__setstate__ = _setstate_pooledbuffer


if get_cl_header_version() >= (2, 0):
def _getstate_svmallocation(self):
import pyopencl as cl

state = {}
state["size"] = self.size

queue, _alloc = _get_queue_for_pickling(self)

a = bytearray(self.size)
cl.enqueue_copy(queue, a, self)

state["_pickle_data"] = a

return state

def _setstate_svmallocation(self, state):
import pyopencl as cl

queue, _alloc = _get_queue_for_pickling(self)

size = state["size"]

a = state["_pickle_data"]
SVMAllocation.__init__(self, queue.context, size, alignment=0, flags=0,
queue=queue)
cl.enqueue_copy(queue, self, a)

SVMAllocation.__getstate__ = _getstate_svmallocation
SVMAllocation.__setstate__ = _setstate_svmallocation

def _getstate_pooled_svm(self):
import pyopencl as cl

state = {}
state["size"] = self.size

queue, _alloc = _get_queue_for_pickling(self)

a = bytearray(self.size)
cl.enqueue_copy(queue, a, self)

state["_pickle_data"] = a

return state

def _setstate_pooled_svm(self, state):
_queue, _alloc = _get_queue_for_pickling(self)
_size = state["size"]
_data = state["_pickle_data"]

# FIXME: Unclear what to do here - PooledSVM does not have __init__

PooledSVM.__getstate__ = _getstate_pooled_svm
PooledSVM.__setstate__ = _setstate_pooled_svm

# }}}

# vim: foldmethod=marker
41 changes: 41 additions & 0 deletions pyopencl/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,47 @@ def __init__(
"than expected, potentially leading to crashes.",
InconsistentOpenCLQueueWarning, stacklevel=2)

# {{{ Pickling

def __getstate__(self):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be useful if this worked for subclasses (liked TaggedCLArray), too. Should be tested, too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think of fdb3525 ?

try:
queue = cl._QUEUE_FOR_PICKLING_TLS.queue
except AttributeError:
queue = None

if queue is None:
raise RuntimeError("CL Array instances can only be pickled while "
"cl.queue_for_pickling is active.")

state = self.__dict__.copy()

del state["allocator"]
del state["context"]
del state["events"]
del state["queue"]
return state

def __setstate__(self, state):
try:
queue = cl._QUEUE_FOR_PICKLING_TLS.queue
alloc = cl._QUEUE_FOR_PICKLING_TLS.alloc
except AttributeError:
queue = None
alloc = None

if queue is None:
raise RuntimeError("CL Array instances can only be unpickled while "
"cl.queue_for_pickling is active.")

self.__dict__.update(state)

self.allocator = alloc
self.context = queue.context
self.events = []
self.queue = queue

# }}}

@property
def ndim(self):
return len(self.shape)
Expand Down
90 changes: 90 additions & 0 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2391,6 +2391,96 @@ def test_xdg_cache_home(ctx_factory):
# }}}


# {{{ test pickling

from pytools.tag import Taggable


class TaggableCLArray(cl_array.Array, Taggable):
def __init__(self, cq, shape, dtype, tags):
super().__init__(cq=cq, shape=shape, dtype=dtype)
self.tags = tags


@pytest.mark.parametrize("use_mempool", [False, True])
def test_array_pickling(ctx_factory, use_mempool):
context = ctx_factory()
queue = cl.CommandQueue(context)

if use_mempool:
alloc = cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue))
else:
alloc = None

a = np.array([1, 2, 3, 4, 5]).astype(np.float32)
a_gpu = cl_array.to_device(queue, a, allocator=alloc)

import pickle
with pytest.raises(RuntimeError):
pickle.dumps(a_gpu)

with cl.queue_for_pickling(queue):
a_gpu_pickled = pickle.loads(pickle.dumps(a_gpu))
assert np.all(a_gpu_pickled.get() == a)

# {{{ subclass test

a_gpu_tagged = TaggableCLArray(queue, a.shape, a.dtype, tags={"foo", "bar"})
a_gpu_tagged.set(a)

with cl.queue_for_pickling(queue):
a_gpu_tagged_pickled = pickle.loads(pickle.dumps(a_gpu_tagged))

assert np.all(a_gpu_tagged_pickled.get() == a)
assert a_gpu_tagged_pickled.tags == a_gpu_tagged.tags

# }}}

# {{{ SVM test

from pyopencl.characterize import has_coarse_grain_buffer_svm

if has_coarse_grain_buffer_svm(queue.device):
from pyopencl.tools import SVMAllocator, SVMPool

alloc = SVMAllocator(context, alignment=0, queue=queue)
if use_mempool:
alloc = SVMPool(alloc)

a_dev = cl_array.to_device(queue, a, allocator=alloc)

with cl.queue_for_pickling(queue, alloc):
a_dev_pickled = pickle.loads(pickle.dumps(a_dev))

assert np.all(a_dev_pickled.get() == a)
assert a_dev_pickled.allocator is alloc

# }}}


def test_buffer_pickling(ctx_factory):
context = ctx_factory()
queue = cl.CommandQueue(context)

a = np.array([1, 2, 3, 4, 5]).astype(np.float32)
a_gpu = cl.Buffer(context, cl.mem_flags.READ_WRITE, a.nbytes)
cl.enqueue_copy(queue, a_gpu, a)

import pickle

with pytest.raises(cl.RuntimeError):
pickle.dumps(a_gpu)

with cl.queue_for_pickling(queue):
a_gpu_pickled = pickle.loads(pickle.dumps(a_gpu))

a_new = np.empty_like(a)
cl.enqueue_copy(queue, a_new, a_gpu_pickled)
assert np.all(a_new == a)

# }}}


def test_numpy_type_promotion_with_cl_arrays(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
Expand Down
Loading