Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions arraycontext/impl/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations


"""
__doc__ = """
.. currentmodule:: arraycontext

A mod :`numpy`-based array context.
A :mod:`numpy`-based array context.

.. autoclass:: NumpyArrayContext
"""
Expand Down
14 changes: 10 additions & 4 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
__doc__ = """
.. currentmodule:: arraycontext

A :mod:`pytato`-based array context defers the evaluation of an array until its
A :mod:`pytato`-based array context defers the evaluation of an array until it is
frozen. The execution contexts for the evaluations are specific to an
:class:`~arraycontext.ArrayContext` type. For ex.
:class:`~arraycontext.ArrayContext` type. For example,
:class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to
JIT-compile and execute the array expressions.

Following :mod:`pytato`-based array context are provided:
The following :mod:`pytato`-based array contexts are provided:

.. autoclass:: PytatoPyOpenCLArrayContext
.. autoclass:: PytatoJAXArrayContext
Expand All @@ -20,6 +20,12 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. automodule:: arraycontext.impl.pytato.compile


Utils
^^^^^

.. automodule:: arraycontext.impl.pytato.utils
"""
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
Expand Down Expand Up @@ -227,7 +233,7 @@ def get_target(self):

class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
"""
A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
the arrays targeting OpenCL for offloading operations.

.. attribute:: queue
Expand Down
92 changes: 91 additions & 1 deletion arraycontext/impl/pytato/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
__doc__ = """
.. autofunction:: transfer_from_numpy
.. autofunction:: transfer_to_numpy
"""


__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
Expand All @@ -22,6 +28,7 @@
THE SOFTWARE.
"""


from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast

Expand All @@ -36,9 +43,10 @@
make_placeholder,
)
from pytato.target.loopy import LoopyPyOpenCLTarget
from pytato.transform import CopyMapper
from pytato.transform import ArrayOrNames, CopyMapper
from pytools import UniqueNameGenerator, memoize_method

from arraycontext import ArrayContext
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis


Expand Down Expand Up @@ -125,4 +133,86 @@ def get_loopy_target(self) -> "lp.PyOpenCLTarget":

# }}}


# {{{ Transfer mappers

class TransferFromNumpyMapper(CopyMapper):
"""A mapper to transfer arrays contained in :class:`~pytato.array.DataWrapper`
instances to be device arrays, using
:meth:`~arraycontext.ArrayContext.from_numpy`.
"""
def __init__(self, actx: ArrayContext) -> None:
super().__init__()
self.actx = actx

def map_data_wrapper(self, expr: DataWrapper) -> Array:
import numpy as np

if not isinstance(expr.data, np.ndarray):
raise ValueError("TransferFromNumpyMapper: tried to transfer data that "
"is already on the device")

# Ideally, this code should just do
# return self.actx.from_numpy(expr.data).tagged(expr.tags),
# but there seems to be no way to transfer the non_equality_tags in that case.
new_dw = self.actx.from_numpy(expr.data)
assert isinstance(new_dw, DataWrapper)

# https://github.com/pylint-dev/pylint/issues/3893
# pylint: disable=unexpected-keyword-arg
return DataWrapper(
data=new_dw.data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)


class TransferToNumpyMapper(CopyMapper):
"""A mapper to transfer arrays contained in :class:`~pytato.array.DataWrapper`
instances to be :class:`numpy.ndarray` instances, using
:meth:`~arraycontext.ArrayContext.to_numpy`.
"""
def __init__(self, actx: ArrayContext) -> None:
super().__init__()
self.actx = actx

def map_data_wrapper(self, expr: DataWrapper) -> Array:
import numpy as np

import arraycontext.impl.pyopencl.taggable_cl_array as tga
if not isinstance(expr.data, tga.TaggableCLArray):
raise ValueError("TransferToNumpyMapper: tried to transfer data that "
"is already on the host")

np_data = self.actx.to_numpy(expr.data)
assert isinstance(np_data, np.ndarray)

# https://github.com/pylint-dev/pylint/issues/3893
# pylint: disable=unexpected-keyword-arg
return DataWrapper(
data=np_data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)


def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
instances to be device arrays, using
:meth:`~arraycontext.ArrayContext.from_numpy`.
"""
return TransferFromNumpyMapper(actx)(expr)


def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
instances to be :class:`numpy.ndarray` instances, using
:meth:`~arraycontext.ArrayContext.to_numpy`.
"""
return TransferToNumpyMapper(actx)(expr)

# }}}

# vim: foldmethod=marker
53 changes: 53 additions & 0 deletions test/test_pytato_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,59 @@ def twice(x):
assert res == 198


def test_transfer(actx_factory):
import numpy as np

import pytato as pt
actx = actx_factory()

# {{{ simple tests

a = actx.from_numpy(np.array([0, 1, 2, 3])).tagged(FooTag())

from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
assert isinstance(a.data, TaggableCLArray)

from arraycontext.impl.pytato.utils import transfer_from_numpy, transfer_to_numpy

ah = transfer_to_numpy(a, actx)
assert ah != a
assert a.tags == ah.tags
assert a.non_equality_tags == ah.non_equality_tags
assert isinstance(ah.data, np.ndarray)

with pytest.raises(ValueError):
_ahh = transfer_to_numpy(ah, actx)

ad = transfer_from_numpy(ah, actx)
assert isinstance(ad.data, TaggableCLArray)
assert ad != ah
assert ad != a # copied DataWrappers compare unequal
assert ad.tags == ah.tags
assert ad.non_equality_tags == ah.non_equality_tags
assert np.array_equal(a.data.get(), ad.data.get())

with pytest.raises(ValueError):
_add = transfer_from_numpy(ad, actx)

# }}}

# {{{ test with DictOfNamedArrays

dag = pt.make_dict_of_named_arrays({
"a_expr": a + 2
})

dagh = transfer_to_numpy(dag, actx)
assert dagh != dag
assert isinstance(dagh["a_expr"].expr.bindings["_in0"].data, np.ndarray)

daghd = transfer_from_numpy(dagh, actx)
assert isinstance(daghd["a_expr"].expr.bindings["_in0"].data, TaggableCLArray)

# }}}


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down