diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index afe4a406..e70b51df 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -309,7 +309,8 @@ def _deserialize_ndarray_container( # type: ignore[misc] result = type(template)(template.shape, dtype=object) for i, subary in serialized: - result[i] = subary + # FIXME: numpy annotations don't seem to handle object arrays very well + result[i] = subary # type: ignore[call-overload] return result diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 62f6354c..8abd73d9 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -110,16 +110,16 @@ def _map_array_container_impl( specific container classes. By default, the recursion is stopped when a non-:class:`ArrayContainer` class is encountered. """ - def rec(_ary: ArrayOrContainer) -> ArrayOrContainer: - if type(_ary) is leaf_cls: # type(ary) is never None - return f(_ary) + def rec(ary_: ArrayOrContainer) -> ArrayOrContainer: + if type(ary_) is leaf_cls: # type(ary) is never None + return f(ary_) try: - iterable = serialize_container(_ary) + iterable = serialize_container(ary_) except NotAnArrayContainerError: - return f(_ary) + return f(ary_) else: - return deserialize_container(_ary, [ + return deserialize_container(ary_, [ (key, frec(subary)) for key, subary in iterable ]) @@ -144,28 +144,28 @@ def _multimap_array_container_impl( # {{{ recursive traversal - def rec(*_args: Any) -> Any: - template_ary = _args[container_indices[0]] + def rec(*args_: Any) -> Any: + template_ary = args_[container_indices[0]] if type(template_ary) is leaf_cls: - return f(*_args) + return f(*args_) try: iterable_template = serialize_container(template_ary) except NotAnArrayContainerError: - return f(*_args) + return f(*args_) else: pass assert all( - type(_args[i]) is type(template_ary) for i in container_indices[1:] + type(args_[i]) is type(template_ary) for i in container_indices[1:] ), f"expected type '{type(template_ary).__name__}'" result = [] - new_args = list(_args) + new_args = list(args_) for subarys in zip( iterable_template, - *[serialize_container(_args[i]) for i in container_indices[1:]], + *[serialize_container(args_[i]) for i in container_indices[1:]], strict=True ): key = None @@ -415,13 +415,13 @@ def rec_keyed_map_array_container( """ def rec(keys: tuple[SerializationKey, ...], - _ary: ArrayOrContainerT) -> ArrayOrContainerT: + ary_: ArrayOrContainerT) -> ArrayOrContainerT: try: - iterable = serialize_container(_ary) + iterable = serialize_container(ary_) except NotAnArrayContainerError: - return cast(ArrayOrContainerT, f(keys, cast(ArrayT, _ary))) + return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_))) else: - return deserialize_container(_ary, [ + return deserialize_container(ary_, [ (key, rec((*keys, key), subary)) for key, subary in iterable ]) @@ -522,14 +522,14 @@ def rec_map_reduce_array_container( or any other such traversal. """ - def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: - if type(_ary) is leaf_class: - return map_func(_ary) + def rec(ary_: ArrayOrContainerT) -> ArrayOrContainerT: + if type(ary_) is leaf_class: + return map_func(ary_) else: try: - iterable = serialize_container(_ary) + iterable = serialize_container(ary_) except NotAnArrayContainerError: - return map_func(_ary) + return map_func(ary_) else: return reduce_func([ rec(subary) for _, subary in iterable diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index d6f90783..af579324 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -83,12 +83,14 @@ def get_default_entrypoint(t_unit): def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes): @memoize_in(actx, _get_scalar_func_loopy_program) def get(c_name, nargs, naxes): - from pymbolic import var + from pymbolic.primitives import Subscript, Variable var_names = [f"i{i}" for i in range(naxes)] size_names = [f"n{i}" for i in range(naxes)] - subscript = tuple(var(vname) for vname in var_names) + subscript = tuple(Variable(vname) for vname in var_names) + from islpy import make_zero_and_vars + v = make_zero_and_vars(var_names, params=size_names) domain = v[0].domain() for vname, sname in zip(var_names, size_names, strict=True): @@ -98,22 +100,22 @@ def get(c_name, nargs, naxes): import loopy as lp - from .loopy import make_loopy_program from arraycontext.transform_metadata import ElementwiseMapKernelTag + + def sub(name: str) -> Variable | Subscript: + return Subscript(Variable(name), subscript) if subscript else Variable(name) + return make_loopy_program( - [domain_bset], - [ + [domain_bset], [ lp.Assignment( - var("out")[subscript], - var(c_name)(*[ - var(f"inp{i}")[subscript] for i in range(nargs)])) - ], - [ - lp.GlobalArg("out", - dtype=None, shape=lp.auto, offset=lp.auto)] + [ - lp.GlobalArg(f"inp{i}", - dtype=None, shape=lp.auto, offset=lp.auto) - for i in range(nargs)] + [...], + sub("out"), + Variable(c_name)(*[sub(f"inp{i}") for i in range(nargs)])) + ], [ + lp.GlobalArg("out", dtype=None, shape=lp.auto, offset=lp.auto) + ] + [ + lp.GlobalArg(f"inp{i}", dtype=None, shape=lp.auto, offset=lp.auto) + for i in range(nargs) + ] + [...], name=f"actx_special_{c_name}", tags=(ElementwiseMapKernelTag(),)) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index ab263304..14d24dd4 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -263,11 +263,11 @@ def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype): "atan2": "arctan2", } - def evaluate(_np, *_args): - func = getattr(_np, sym_name, - getattr(_np, c_to_numpy_arc_functions.get(sym_name, sym_name))) + def evaluate(np_, *args_): + func = getattr(np_, sym_name, + getattr(np_, c_to_numpy_arc_functions.get(sym_name, sym_name))) - return func(*_args) + return func(*args_) assert_close_to_numpy_in_containers(actx, evaluate, args)