@@ -601,29 +601,100 @@ def __getitem__(
601601 Returns self[key].
602602
603603 Some rules:
604- * Integers removes the dim.
605- * Slices and ellipsis maintains same dim.
606- * None adds a dim.
604+ * Integer, removes the dim.
605+ * Slice and ellipsis, maintains same dim.
606+ * None, adds a dim.
607+ * integer array, removes the dim if 1 sized(?)
608+ * boolean array,
607609 * tuple follows above but on that specific axis.
608610
609611 Examples
610612 --------
611613
612- 1D
614+ Basic indexing:
615+
616+ >>> x = NamedArray(("x",), np.array([3, 5, 7]))
617+ >>> x[0]
618+ <xarray.NamedArray ()> Size: 8B
619+ np.int64(3)
620+ >>> x[-1]
621+ <xarray.NamedArray ()> Size: 8B
622+ np.int64(7)
623+ >>> x[0:2]
624+ <xarray.NamedArray (x: 2)> Size: 16B
625+ array([3, 5])
626+ >>> x[None]
627+ <xarray.NamedArray (dim_1: 1, x: 3)> Size: 24B
628+ array([[3, 5, 7]])
629+ >>> x[...]
630+ <xarray.NamedArray (x: 3)> Size: 24B
631+ array([3, 5, 7])
613632
614- >>> x = NamedArray(("x",), np.array([0, 1, 2]))
615- >>> key = NamedArray(("x",), np.array([1, 0, 0], dtype=bool))
616- >>> xm = x[key]
617- >>> xm.dims, xm.shape
618- (('x',), (1,))
633+ Indexing with integer array:
619634
620- >>> x = NamedArray(("x",), np.array([0, 1, 2]))
621- >>> key = NamedArray(("x",), np.array([0, 0, 0], dtype=bool))
622- >>> xm = x[key]
623- >>> xm.dims, xm.shape
624- (('x',), (0,))
635+ >>> key = NamedArray(("x",), np.array([0, 0, 0, -1], dtype=int))
636+ >>> x[key]
637+ <xarray.NamedArray (x: 4)> Size: 32B
638+ array([3, 3, 3, 7])
639+
640+ Indexing with boolean array:
625641
626- Setup a ND array:
642+ >>> key = NamedArray(("x",), np.array([1, 0, 1], dtype=bool))
643+ >>> x[key]
644+ <xarray.NamedArray (x: 2)> Size: 16B
645+ array([3, 7])
646+ >>> key = NamedArray(("x",), np.array([0, 0, 0], dtype=bool))
647+ >>> x[key]
648+ <xarray.NamedArray (x: 0)> Size: 0B
649+ array([], dtype=int64)
650+
651+ Multidmensional basic indexing:
652+
653+ >>> x = NamedArray(("z", "y", "x"), np.array([[[1], [2], [3]], [[4], [5], [6]]]))
654+ >>> x[0, 0, 0]
655+ <xarray.NamedArray ()> Size: 8B
656+ np.int64(1)
657+ >>> x[1:2]
658+ <xarray.NamedArray (z: 1, y: 3, x: 1)> Size: 24B
659+ array([[[4],
660+ [5],
661+ [6]]])
662+ >>> x[:, None, :, :]
663+ <xarray.NamedArray (z: 2, dim_3: 1, y: 3, x: 1)> Size: 48B
664+ array([[[[1],
665+ [2],
666+ [3]]],
667+ <BLANKLINE>
668+ <BLANKLINE>
669+ [[[4],
670+ [5],
671+ [6]]]])
672+ >>> x[..., 0]
673+ <xarray.NamedArray (z: 2, y: 3)> Size: 48B
674+ array([[1, 2, 3],
675+ [4, 5, 6]])
676+
677+ Multidimensional indexing with integer array:
678+
679+ >>> x = NamedArray(("z", "y", "x"), np.array([[[1], [2], [3]], [[4], [5], [6]]]))
680+ >>> key = NamedArray(("z",), np.array([1, -1]))
681+ >>> x[key]
682+ <xarray.NamedArray (z: 2, y: 3, x: 1)> Size: 48B
683+ array([[[4],
684+ [5],
685+ [6]],
686+ <BLANKLINE>
687+ [[4],
688+ [5],
689+ [6]]])
690+ >>> x[key, 0]
691+ <xarray.NamedArray (z: 2, x: 1)> Size: 16B
692+ array([[4],
693+ [4]])
694+
695+ OLD
696+
697+ ND array:
627698
628699 >>> x = NamedArray(("x", "y"), np.arange(3 * 4).reshape((3, 4)))
629700 >>> xm = x[0]
@@ -635,20 +706,22 @@ def __getitem__(
635706 >>> xm = x[None]
636707 >>> xm.dims, xm.shape
637708 (('dim_2', 'x', 'y'), (1, 3, 4))
638-
639709 >>> key = NamedArray(("x", "y"), np.ones((3, 4), dtype=bool))
640710 >>> xm = x[key]
641711 >>> xm.dims, xm.shape
642712 ((('x', 'y'),), (12,))
643713
644- 0D
714+ 0D array:
645715
646716 >>> x = NamedArray((), np.array(False, dtype=np.bool))
647717 >>> key = NamedArray((), np.array(False, dtype=np.bool))
648718 >>> xm = x[key]
649719 >>> xm.dims, xm.shape
650720 (('dim_0',), (0,))
651721 """
722+ from xarray .namedarray ._array_api ._data_type_functions import (
723+ isdtype ,
724+ )
652725 from xarray .namedarray ._array_api ._manipulation_functions import (
653726 _broadcast_arrays ,
654727 )
@@ -658,16 +731,30 @@ def __getitem__(
658731 _flatten_dims ,
659732 )
660733
661- if isinstance (key , NamedArray ):
734+ if isinstance (key , NamedArray ) and isdtype ( key . dtype , "bool" ) :
662735 self_new , key_new = _broadcast_arrays (self , key )
736+ # self_new, key_new = self, key
663737 _data = self_new ._data [key_new ._data ]
664738 _dims = _flatten_dims (_atleast1d_dims (self_new .dims ))
739+ # _dims = dims
665740 return self ._new (_dims , _data )
666- elif isinstance (key , int | slice | tuple ) or key is None or key is ...:
741+ elif (
742+ isinstance (key , int | slice | tuple | NamedArray )
743+ or key is None
744+ or key is ...
745+ ):
667746 # TODO: __getitem__ not always available, use expand_dims
668- _data = self ._data [key ]
669747 _key_tuple = key if isinstance (key , tuple ) else (key ,)
670- _dims = _dims_from_tuple_indexing (self .dims , _key_tuple )
748+ # _dims = _dims_from_tuple_indexing(self.dims, _key_tuple)
749+ _dims = _dims_from_tuple_indexing (
750+ self .dims ,
751+ tuple (k .dims if isinstance (k , NamedArray ) else k for k in _key_tuple ),
752+ )
753+
754+ _data = self ._data [
755+ tuple (k ._data if isinstance (k , NamedArray ) else k for k in _key_tuple )
756+ ]
757+
671758 return self ._new (_dims , _data )
672759 else :
673760 raise NotImplementedError (f"{ key = } is not supported" )
0 commit comments