|
26 | 26 | pandas_requirement_message, |
27 | 27 | pyarrow_requirement_message, |
28 | 28 | ) |
29 | | -from pyspark.sql.pandas.typehints import infer_eval_type |
| 29 | +from pyspark.sql.pandas.typehints import infer_eval_type, infer_group_pandas_eval_type |
30 | 30 | from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType |
31 | 31 | from pyspark.sql import Row |
| 32 | +from pyspark.util import PythonEvalType |
32 | 33 |
|
33 | 34 | if have_pandas: |
34 | 35 | import pandas as pd |
@@ -186,6 +187,49 @@ def func() -> float: |
186 | 187 | PandasUDFType.GROUPED_AGG, |
187 | 188 | ) |
188 | 189 |
|
| 190 | + def test_type_annotation_group_map(self): |
| 191 | + # pd.DataFrame -> pd.DataFrame |
| 192 | + def func(col: pd.DataFrame) -> pd.DataFrame: |
| 193 | + pass |
| 194 | + |
| 195 | + self.assertEqual( |
| 196 | + infer_group_pandas_eval_type(signature(func), get_type_hints(func)), |
| 197 | + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, |
| 198 | + ) |
| 199 | + |
| 200 | + # Tuple[Any, ...], pd.DataFrame -> pd.DataFrame |
| 201 | + def func(key: Tuple, col: pd.DataFrame) -> pd.DataFrame: |
| 202 | + pass |
| 203 | + |
| 204 | + self.assertEqual( |
| 205 | + infer_group_pandas_eval_type(signature(func), get_type_hints(func)), |
| 206 | + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, |
| 207 | + ) |
| 208 | + |
| 209 | + # Iterator[pd.DataFrame] -> Iterator[pd.DataFrame] |
| 210 | + def func(col: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: |
| 211 | + pass |
| 212 | + |
| 213 | + self.assertEqual( |
| 214 | + infer_group_pandas_eval_type(signature(func), get_type_hints(func)), |
| 215 | + PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, |
| 216 | + ) |
| 217 | + |
| 218 | + # Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame] |
| 219 | + def func(key: Tuple, col: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: |
| 220 | + pass |
| 221 | + |
| 222 | + self.assertEqual( |
| 223 | + infer_group_pandas_eval_type(signature(func), get_type_hints(func)), |
| 224 | + PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, |
| 225 | + ) |
| 226 | + |
| 227 | + # Should return None for unsupported signatures |
| 228 | + def func(col: Iterator[pd.Series]) -> Iterator[pd.Series]: |
| 229 | + pass |
| 230 | + |
| 231 | + self.assertEqual(infer_group_pandas_eval_type(signature(func), get_type_hints(func)), None) |
| 232 | + |
189 | 233 | def test_type_annotation_negative(self): |
190 | 234 | def func(col: str) -> pd.Series: |
191 | 235 | pass |
|
0 commit comments