Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ def _deserialize_ndarray_container( # type: ignore[misc]

result = type(template)(template.shape, dtype=object)
for i, subary in serialized:
result[i] = subary
# FIXME: numpy annotations don't seem to handle object arrays very well
result[i] = subary # type: ignore[call-overload]

return result

Expand Down
44 changes: 22 additions & 22 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,16 @@ def _map_array_container_impl(
specific container classes. By default, the recursion is stopped when
a non-:class:`ArrayContainer` class is encountered.
"""
def rec(_ary: ArrayOrContainer) -> ArrayOrContainer:
if type(_ary) is leaf_cls: # type(ary) is never None
return f(_ary)
def rec(ary_: ArrayOrContainer) -> ArrayOrContainer:
if type(ary_) is leaf_cls: # type(ary) is never None
return f(ary_)

try:
iterable = serialize_container(_ary)
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return f(_ary)
return f(ary_)
else:
return deserialize_container(_ary, [
return deserialize_container(ary_, [
(key, frec(subary)) for key, subary in iterable
])

Expand All @@ -144,28 +144,28 @@ def _multimap_array_container_impl(

# {{{ recursive traversal

def rec(*_args: Any) -> Any:
template_ary = _args[container_indices[0]]
def rec(*args_: Any) -> Any:
template_ary = args_[container_indices[0]]
if type(template_ary) is leaf_cls:
return f(*_args)
return f(*args_)

try:
iterable_template = serialize_container(template_ary)
except NotAnArrayContainerError:
return f(*_args)
return f(*args_)
else:
pass

assert all(
type(_args[i]) is type(template_ary) for i in container_indices[1:]
type(args_[i]) is type(template_ary) for i in container_indices[1:]
), f"expected type '{type(template_ary).__name__}'"

result = []
new_args = list(_args)
new_args = list(args_)

for subarys in zip(
iterable_template,
*[serialize_container(_args[i]) for i in container_indices[1:]],
*[serialize_container(args_[i]) for i in container_indices[1:]],
strict=True
):
key = None
Expand Down Expand Up @@ -415,13 +415,13 @@ def rec_keyed_map_array_container(
"""

def rec(keys: tuple[SerializationKey, ...],
_ary: ArrayOrContainerT) -> ArrayOrContainerT:
ary_: ArrayOrContainerT) -> ArrayOrContainerT:
try:
iterable = serialize_container(_ary)
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, _ary)))
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
else:
return deserialize_container(_ary, [
return deserialize_container(ary_, [
(key, rec((*keys, key), subary)) for key, subary in iterable
])

Expand Down Expand Up @@ -522,14 +522,14 @@ def rec_map_reduce_array_container(

or any other such traversal.
"""
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
if type(_ary) is leaf_class:
return map_func(_ary)
def rec(ary_: ArrayOrContainerT) -> ArrayOrContainerT:
if type(ary_) is leaf_class:
return map_func(ary_)
else:
try:
iterable = serialize_container(_ary)
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return map_func(_ary)
return map_func(ary_)
else:
return reduce_func([
rec(subary) for _, subary in iterable
Expand Down
32 changes: 17 additions & 15 deletions arraycontext/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ def get_default_entrypoint(t_unit):
def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
@memoize_in(actx, _get_scalar_func_loopy_program)
def get(c_name, nargs, naxes):
from pymbolic import var
from pymbolic.primitives import Subscript, Variable

var_names = [f"i{i}" for i in range(naxes)]
size_names = [f"n{i}" for i in range(naxes)]
subscript = tuple(var(vname) for vname in var_names)
subscript = tuple(Variable(vname) for vname in var_names)

from islpy import make_zero_and_vars

v = make_zero_and_vars(var_names, params=size_names)
domain = v[0].domain()
for vname, sname in zip(var_names, size_names, strict=True):
Expand All @@ -98,22 +100,22 @@ def get(c_name, nargs, naxes):

import loopy as lp

from .loopy import make_loopy_program
from arraycontext.transform_metadata import ElementwiseMapKernelTag

def sub(name: str) -> Variable | Subscript:
return Subscript(Variable(name), subscript) if subscript else Variable(name)

return make_loopy_program(
[domain_bset],
[
[domain_bset], [
lp.Assignment(
var("out")[subscript],
var(c_name)(*[
var(f"inp{i}")[subscript] for i in range(nargs)]))
],
[
lp.GlobalArg("out",
dtype=None, shape=lp.auto, offset=lp.auto)] + [
lp.GlobalArg(f"inp{i}",
dtype=None, shape=lp.auto, offset=lp.auto)
for i in range(nargs)] + [...],
sub("out"),
Variable(c_name)(*[sub(f"inp{i}") for i in range(nargs)]))
], [
lp.GlobalArg("out", dtype=None, shape=lp.auto, offset=lp.auto)
] + [
lp.GlobalArg(f"inp{i}", dtype=None, shape=lp.auto, offset=lp.auto)
for i in range(nargs)
] + [...],
name=f"actx_special_{c_name}",
tags=(ElementwiseMapKernelTag(),))

Expand Down
8 changes: 4 additions & 4 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,11 @@ def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype):
"atan2": "arctan2",
}

def evaluate(_np, *_args):
func = getattr(_np, sym_name,
getattr(_np, c_to_numpy_arc_functions.get(sym_name, sym_name)))
def evaluate(np_, *args_):
func = getattr(np_, sym_name,
getattr(np_, c_to_numpy_arc_functions.get(sym_name, sym_name)))

return func(*_args)
return func(*args_)

assert_close_to_numpy_in_containers(actx, evaluate, args)

Expand Down
Loading