Skip to content

Commit 9df4590

Browse files
committed
fix: type infer
1 parent 5815166 commit 9df4590

File tree

3 files changed

+65
-3
lines changed

3 files changed

+65
-3
lines changed

python/pyspark/sql/pandas/functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
322322
pyspark.sql.GroupedData.applyInArrow
323323
pyspark.sql.PandasCogroupedOps.applyInArrow
324324
pyspark.sql.UDFRegistration.register
325+
pyspark.sql.GroupedData.applyInPandas
325326
"""
326327
require_minimum_pyarrow_version()
327328

@@ -346,6 +347,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
346347
.. versionchanged:: 4.0.0
347348
Supports keyword-arguments in SCALAR and GROUPED_AGG type.
348349
350+
.. versionchanged:: 4.1.0
351+
Supports iterator API in GROUPED_MAP type.
352+
349353
Parameters
350354
----------
351355
f : function, optional

python/pyspark/sql/pandas/typehints.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,22 @@ def infer_group_pandas_eval_type(
456456
if is_iterator_dataframe or is_iterator_dataframe_with_keys:
457457
return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
458458

459-
# Default to non-iterator (standard grouped map)
460-
return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
459+
# pd.DataFrame -> pd.DataFrame
460+
is_dataframe = (
461+
len(parameters_sig) == 1
462+
and parameters_sig[0] == pd.DataFrame
463+
and return_annotation == pd.DataFrame
464+
)
465+
# Tuple[Any, ...], pd.DataFrame -> pd.DataFrame
466+
is_dataframe_with_keys = (
467+
len(parameters_sig) == 2
468+
and parameters_sig[1] == pd.DataFrame
469+
and return_annotation == pd.DataFrame
470+
)
471+
if is_dataframe or is_dataframe_with_keys:
472+
return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
473+
474+
return None
461475

462476

463477
def infer_group_pandas_eval_type_from_func(

python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
pandas_requirement_message,
2727
pyarrow_requirement_message,
2828
)
29-
from pyspark.sql.pandas.typehints import infer_eval_type
29+
from pyspark.sql.pandas.typehints import infer_eval_type, infer_group_pandas_eval_type
3030
from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
3131
from pyspark.sql import Row
32+
from pyspark.util import PythonEvalType
3233

3334
if have_pandas:
3435
import pandas as pd
@@ -186,6 +187,49 @@ def func() -> float:
186187
PandasUDFType.GROUPED_AGG,
187188
)
188189

190+
def test_type_annotation_group_map(self):
191+
# pd.DataFrame -> pd.DataFrame
192+
def func(col: pd.DataFrame) -> pd.DataFrame:
193+
pass
194+
195+
self.assertEqual(
196+
infer_group_pandas_eval_type(signature(func), get_type_hints(func)),
197+
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
198+
)
199+
200+
# Tuple[Any, ...], pd.DataFrame -> pd.DataFrame
201+
def func(key: Tuple, col: pd.DataFrame) -> pd.DataFrame:
202+
pass
203+
204+
self.assertEqual(
205+
infer_group_pandas_eval_type(signature(func), get_type_hints(func)),
206+
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
207+
)
208+
209+
# Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
210+
def func(col: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
211+
pass
212+
213+
self.assertEqual(
214+
infer_group_pandas_eval_type(signature(func), get_type_hints(func)),
215+
PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
216+
)
217+
218+
# Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
219+
def func(key: Tuple, col: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
220+
pass
221+
222+
self.assertEqual(
223+
infer_group_pandas_eval_type(signature(func), get_type_hints(func)),
224+
PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
225+
)
226+
227+
# Should return None for unsupported signatures
228+
def func(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
229+
pass
230+
231+
self.assertEqual(infer_group_pandas_eval_type(signature(func), get_type_hints(func)), None)
232+
189233
def test_type_annotation_negative(self):
190234
def func(col: str) -> pd.Series:
191235
pass

0 commit comments

Comments
 (0)