Skip to content

Commit b8106f2

Browse files
committed
Try to handle integer arrays in indexing
1 parent aef66ab commit b8106f2

File tree

3 files changed

+127
-29
lines changed

3 files changed

+127
-29
lines changed

xarray/namedarray/_array_api/_utils.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
_DimsLike2,
2121
_dtype,
2222
_IndexKeys,
23+
_IndexKeysDims,
2324
_IndexKeysNoEllipsis,
2425
_Shape,
2526
)
@@ -488,7 +489,7 @@ def _check_indexing_dims(original_dims: _Dims, indexing_dims: _Dims) -> None:
488489
)
489490

490491

491-
def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims:
492+
def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeysDims) -> _Dims:
492493
"""
493494
Get the expected dims when using tuples in __getitem__.
494495
@@ -517,15 +518,18 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims:
517518
>>> _dims_from_tuple_indexing(("x",), (..., 0))
518519
()
519520
520-
Indexing array
521+
Indexing array is converted to dims
521522
522523
>>> import numpy as np
523-
>>> key = (0, NamedArray((), np.array(0, dtype=int)))
524+
>>> key = (0, NamedArray((), np.array(0, dtype=int)).dims)
524525
>>> _dims_from_tuple_indexing(("x", "y", "z"), key)
525526
('z',)
526-
>>> key = (0, NamedArray(("y",), np.array([0], dtype=int)))
527+
>>> key = (0, NamedArray(("y",), np.array([0], dtype=int)).dims)
527528
>>> _dims_from_tuple_indexing(("x", "y", "z"), key)
528529
('y', 'z')
530+
>>> key = (NamedArray(("x",), np.array([0])).dims, 0)
531+
>>> _dims_from_tuple_indexing(("x", "y", "z"), key)
532+
('x', 'z')
529533
"""
530534
key_no_ellipsis = _replace_ellipsis(key, len(dims))
531535

@@ -542,14 +546,20 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims:
542546
elif isinstance(k, slice):
543547
# Slice retains the dimension.
544548
j += 1
545-
elif isinstance(k, NamedArray):
546-
if len(k.dims) == 0:
549+
elif isinstance(k, tuple):
550+
_dims = k
551+
if len(_dims) == 0:
547552
# if 0 dim, removes 1 dimension
548553
new_dims.pop(j)
549-
else:
554+
elif len(_dims) == 1:
550555
# same size retains the dimension:
551-
_check_indexing_dims(dims[i : i + 1], k.dims)
556+
_check_indexing_dims(dims[i : i + 1], _dims)
552557
j += 1
558+
# new_dims.pop(j)
559+
else:
560+
raise NotImplementedError(
561+
f"What happens here? {key_no_ellipsis=}, {dims=}, {i=}, {k=}"
562+
)
553563

554564
return tuple(new_dims)
555565

xarray/namedarray/_typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def imag(self) -> _T_co: ...
143143
_IndexKey = Union[_IndexKeyNoEllipsis, EllipsisType]
144144
_IndexKeysNoEllipsis = tuple[_IndexKeyNoEllipsis, ...]
145145
_IndexKeys = tuple[_IndexKey, ...] # tuple[Union[_IndexKey, None], ...]
146+
_IndexKeysDims = tuple[Union[_IndexKey, _Dims], ...]
146147
_IndexKeyLike = Union[_IndexKey, _IndexKeys]
147148

148149
_AttrsLike = Union[Mapping[Any, Any], None]

xarray/namedarray/core.py

Lines changed: 108 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -601,29 +601,100 @@ def __getitem__(
601601
Returns self[key].
602602
603603
Some rules:
604-
* Integers removes the dim.
605-
* Slices and ellipsis maintains same dim.
606-
* None adds a dim.
604+
* Integer, removes the dim.
605+
* Slice and ellipsis, maintains same dim.
606+
* None, adds a dim.
607+
* integer array, removes the dim if 1 sized(?)
608+
* boolean array,
607609
* tuple follows above but on that specific axis.
608610
609611
Examples
610612
--------
611613
612-
1D
614+
Basic indexing:
615+
616+
>>> x = NamedArray(("x",), np.array([3, 5, 7]))
617+
>>> x[0]
618+
<xarray.NamedArray ()> Size: 8B
619+
np.int64(3)
620+
>>> x[-1]
621+
<xarray.NamedArray ()> Size: 8B
622+
np.int64(7)
623+
>>> x[0:2]
624+
<xarray.NamedArray (x: 2)> Size: 16B
625+
array([3, 5])
626+
>>> x[None]
627+
<xarray.NamedArray (dim_1: 1, x: 3)> Size: 24B
628+
array([[3, 5, 7]])
629+
>>> x[...]
630+
<xarray.NamedArray (x: 3)> Size: 24B
631+
array([3, 5, 7])
613632
614-
>>> x = NamedArray(("x",), np.array([0, 1, 2]))
615-
>>> key = NamedArray(("x",), np.array([1, 0, 0], dtype=bool))
616-
>>> xm = x[key]
617-
>>> xm.dims, xm.shape
618-
(('x',), (1,))
633+
Indexing with integer array:
619634
620-
>>> x = NamedArray(("x",), np.array([0, 1, 2]))
621-
>>> key = NamedArray(("x",), np.array([0, 0, 0], dtype=bool))
622-
>>> xm = x[key]
623-
>>> xm.dims, xm.shape
624-
(('x',), (0,))
635+
>>> key = NamedArray(("x",), np.array([0, 0, 0, -1], dtype=int))
636+
>>> x[key]
637+
<xarray.NamedArray (x: 4)> Size: 32B
638+
array([3, 3, 3, 7])
639+
640+
Indexing with boolean array:
625641
626-
Setup a ND array:
642+
>>> key = NamedArray(("x",), np.array([1, 0, 1], dtype=bool))
643+
>>> x[key]
644+
<xarray.NamedArray (x: 2)> Size: 16B
645+
array([3, 7])
646+
>>> key = NamedArray(("x",), np.array([0, 0, 0], dtype=bool))
647+
>>> x[key]
648+
<xarray.NamedArray (x: 0)> Size: 0B
649+
array([], dtype=int64)
650+
651+
Multidmensional basic indexing:
652+
653+
>>> x = NamedArray(("z", "y", "x"), np.array([[[1], [2], [3]], [[4], [5], [6]]]))
654+
>>> x[0, 0, 0]
655+
<xarray.NamedArray ()> Size: 8B
656+
np.int64(1)
657+
>>> x[1:2]
658+
<xarray.NamedArray (z: 1, y: 3, x: 1)> Size: 24B
659+
array([[[4],
660+
[5],
661+
[6]]])
662+
>>> x[:, None, :, :]
663+
<xarray.NamedArray (z: 2, dim_3: 1, y: 3, x: 1)> Size: 48B
664+
array([[[[1],
665+
[2],
666+
[3]]],
667+
<BLANKLINE>
668+
<BLANKLINE>
669+
[[[4],
670+
[5],
671+
[6]]]])
672+
>>> x[..., 0]
673+
<xarray.NamedArray (z: 2, y: 3)> Size: 48B
674+
array([[1, 2, 3],
675+
[4, 5, 6]])
676+
677+
Multidimensional indexing with integer array:
678+
679+
>>> x = NamedArray(("z", "y", "x"), np.array([[[1], [2], [3]], [[4], [5], [6]]]))
680+
>>> key = NamedArray(("z",), np.array([1, -1]))
681+
>>> x[key]
682+
<xarray.NamedArray (z: 2, y: 3, x: 1)> Size: 48B
683+
array([[[4],
684+
[5],
685+
[6]],
686+
<BLANKLINE>
687+
[[4],
688+
[5],
689+
[6]]])
690+
>>> x[key, 0]
691+
<xarray.NamedArray (z: 2, x: 1)> Size: 16B
692+
array([[4],
693+
[4]])
694+
695+
OLD
696+
697+
ND array:
627698
628699
>>> x = NamedArray(("x", "y"), np.arange(3 * 4).reshape((3, 4)))
629700
>>> xm = x[0]
@@ -635,20 +706,22 @@ def __getitem__(
635706
>>> xm = x[None]
636707
>>> xm.dims, xm.shape
637708
(('dim_2', 'x', 'y'), (1, 3, 4))
638-
639709
>>> key = NamedArray(("x", "y"), np.ones((3, 4), dtype=bool))
640710
>>> xm = x[key]
641711
>>> xm.dims, xm.shape
642712
((('x', 'y'),), (12,))
643713
644-
0D
714+
0D array:
645715
646716
>>> x = NamedArray((), np.array(False, dtype=np.bool))
647717
>>> key = NamedArray((), np.array(False, dtype=np.bool))
648718
>>> xm = x[key]
649719
>>> xm.dims, xm.shape
650720
(('dim_0',), (0,))
651721
"""
722+
from xarray.namedarray._array_api._data_type_functions import (
723+
isdtype,
724+
)
652725
from xarray.namedarray._array_api._manipulation_functions import (
653726
_broadcast_arrays,
654727
)
@@ -658,16 +731,30 @@ def __getitem__(
658731
_flatten_dims,
659732
)
660733

661-
if isinstance(key, NamedArray):
734+
if isinstance(key, NamedArray) and isdtype(key.dtype, "bool"):
662735
self_new, key_new = _broadcast_arrays(self, key)
736+
# self_new, key_new = self, key
663737
_data = self_new._data[key_new._data]
664738
_dims = _flatten_dims(_atleast1d_dims(self_new.dims))
739+
# _dims = dims
665740
return self._new(_dims, _data)
666-
elif isinstance(key, int | slice | tuple) or key is None or key is ...:
741+
elif (
742+
isinstance(key, int | slice | tuple | NamedArray)
743+
or key is None
744+
or key is ...
745+
):
667746
# TODO: __getitem__ not always available, use expand_dims
668-
_data = self._data[key]
669747
_key_tuple = key if isinstance(key, tuple) else (key,)
670-
_dims = _dims_from_tuple_indexing(self.dims, _key_tuple)
748+
# _dims = _dims_from_tuple_indexing(self.dims, _key_tuple)
749+
_dims = _dims_from_tuple_indexing(
750+
self.dims,
751+
tuple(k.dims if isinstance(k, NamedArray) else k for k in _key_tuple),
752+
)
753+
754+
_data = self._data[
755+
tuple(k._data if isinstance(k, NamedArray) else k for k in _key_tuple)
756+
]
757+
671758
return self._new(_dims, _data)
672759
else:
673760
raise NotImplementedError(f"{key=} is not supported")

0 commit comments

Comments
 (0)