Skip to content

Commit 2e8f434

Browse files
committed
dtype numpy extension array
1 parent b623b4b commit 2e8f434

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

pandas-stubs/core/construction.pyi

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ from pandas._libs.tslibs.period import Period
4040
from pandas._libs.tslibs.timedeltas import Timedelta
4141
from pandas._libs.tslibs.timestamps import Timestamp
4242
from pandas._typing import (
43-
BuiltinFloatDtypeArg,
44-
BuiltinIntDtypeArg,
45-
BuiltinStrDtypeArg,
43+
BuiltinDtypeArg,
4644
CategoryDtypeArg,
4745
IntervalT,
48-
NumpyDtypeArg,
46+
NumpyNotTimeDtypeArg,
47+
NumpyTimedeltaDtypeArg,
48+
NumpyTimestampDtypeArg,
4949
PandasBooleanDtypeArg,
5050
PandasFloatDtypeArg,
5151
PandasIntDtypeArg,
@@ -119,15 +119,15 @@ def array( # type: ignore[overload-overlap] # pyright: ignore[reportOverlapping
119119
@overload
120120
def array( # type: ignore[overload-overlap]
121121
data: Sequence[int | np.integer | NAType | None] | np_ndarray_anyint | IntegerArray,
122-
dtype: BuiltinIntDtypeArg | PandasIntDtypeArg | PandasUIntDtypeArg | None = None,
122+
dtype: PandasIntDtypeArg | PandasUIntDtypeArg | None = None,
123123
copy: bool = True,
124124
) -> IntegerArray: ...
125125
@overload
126126
def array( # type: ignore[overload-overlap]
127127
data: (
128128
Sequence[float | np.floating | NAType | None] | np_ndarray_float | FloatingArray
129129
),
130-
dtype: BuiltinFloatDtypeArg | PandasFloatDtypeArg | None = None,
130+
dtype: PandasFloatDtypeArg | None = None,
131131
copy: bool = True,
132132
) -> FloatingArray: ...
133133
@overload
@@ -140,7 +140,7 @@ def array( # type: ignore[overload-overlap]
140140
| DatetimeIndex
141141
| Series[Timestamp]
142142
),
143-
dtype: PandasTimestampDtypeArg | None = None,
143+
dtype: PandasTimestampDtypeArg | NumpyTimestampDtypeArg | None = None,
144144
copy: bool = True,
145145
) -> DatetimeArray: ...
146146
@overload
@@ -152,13 +152,13 @@ def array(
152152
| TimedeltaIndex
153153
| Series[Timedelta]
154154
),
155-
dtype: None = None,
155+
dtype: NumpyTimedeltaDtypeArg | None = None,
156156
copy: bool = True,
157157
) -> TimedeltaArray: ...
158158
@overload
159159
def array( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
160160
data: SequenceNotStr[str | np.str_ | NAType | None] | np_ndarray_str | StringArray,
161-
dtype: BuiltinStrDtypeArg | PandasStrDtypeArg | None = None,
161+
dtype: PandasStrDtypeArg | None = None,
162162
copy: bool = True,
163163
) -> StringArray: ...
164164
@overload
@@ -175,7 +175,7 @@ def array( # type: ignore[overload-overlap]
175175
@overload
176176
def array(
177177
data: SequenceNotStr[object] | np_ndarray | NumpyExtensionArray | RangeIndex,
178-
dtype: NumpyDtypeArg | None = None,
178+
dtype: BuiltinDtypeArg | NumpyNotTimeDtypeArg | None = None,
179179
copy: bool = True,
180180
) -> NumpyExtensionArray: ...
181181
@overload

tests/arrays/test_numpy_extension_array.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import numpy as np
22
import pandas as pd
33
from pandas.core.arrays.numpy_ import NumpyExtensionArray
4+
import pytest
45
from typing_extensions import assert_type
56

6-
from tests import check
7+
from tests import (
8+
BuiltinDtypeArg,
9+
NumpyNotTimeDtypeArg,
10+
check,
11+
get_dtype,
12+
)
713

814

915
def test_constructor() -> None:
@@ -31,3 +37,17 @@ def test_constructor() -> None:
3137
assert_type(pd.array(pd.RangeIndex(0, 1)), NumpyExtensionArray),
3238
NumpyExtensionArray,
3339
)
40+
41+
42+
@pytest.mark.parametrize("dtype", get_dtype(BuiltinDtypeArg | NumpyNotTimeDtypeArg))
43+
def test_constructors_dtype(dtype: BuiltinDtypeArg | NumpyNotTimeDtypeArg):
44+
if dtype == "V" or "void" in str(dtype):
45+
check(
46+
assert_type(pd.array([b"1"], dtype=dtype), NumpyExtensionArray),
47+
NumpyExtensionArray,
48+
)
49+
else:
50+
check(
51+
assert_type(pd.array([1], dtype=dtype), NumpyExtensionArray),
52+
NumpyExtensionArray,
53+
)

0 commit comments

Comments
 (0)