
Commit 319134e

ueshin authored and Yicong-Huang committed
[SPARK-53976][PYTHON] Support logging in Pandas/Arrow UDFs
Supports logging in Pandas/Arrow UDFs. The basic logging infrastructure was introduced in apache#52689, and the other UDF types should support logging as well; this change adds that support for Pandas and Arrow UDFs.

User-facing change: yes, the logging feature becomes available in Pandas/Arrow UDFs. Tested by adding the related tests. No generative AI tooling was used.

Closes apache#52785 from ueshin/issues/SPARK-53976/pandas_arrow_udfs.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent f874541 · commit 319134e

24 files changed: +700 −28 lines
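For orientation, a minimal sketch of what this change enables, not taken from the commit itself: a Pandas UDF can log through the standard logging module, and with the worker-logging conf enabled the records become queryable from the session log table. The conf and table names below come from the tests in this diff; the `spark` session object and the `plus_one` UDF are assumptions for illustration.

import logging

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def plus_one(s: pd.Series) -> pd.Series:
    # With worker logging enabled, this record is captured into the
    # session's python_worker_logs table instead of only executor stderr.
    logging.getLogger("my_udf").warning("processing %d rows", len(s))
    return s + 1

# Assumes `spark` is an active SparkSession with JVM access (not remote-only).
spark.conf.set("spark.sql.pyspark.worker.logging.enabled", "true")
spark.range(6).select(plus_one("id")).show()

# Each log row carries level, msg, context (e.g. func_name), and logger name.
spark.table("system.session.python_worker_logs").show(truncate=False)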

python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py

Lines changed: 46 additions & 0 deletions
@@ -17,6 +17,7 @@
 import os
 import time
 import unittest
+import logging

 from pyspark.errors import PythonException
 from pyspark.sql import Row
@@ -26,6 +27,8 @@
     have_pyarrow,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only

 if have_pyarrow:
     import pyarrow as pa
@@ -367,6 +370,49 @@ def test_negative_and_zero_batch_size(self):
         with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
             CogroupedMapInArrowTestsMixin.test_apply_in_arrow(self)

+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_cogroup_apply_in_arrow_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(left, right):
+            assert isinstance(left, pa.Table)
+            assert isinstance(right, pa.Table)
+            logger = logging.getLogger("test_arrow_cogrouped_map")
+            logger.warning(
+                "arrow cogrouped map: "
+                + f"{dict(v1=left['v1'].to_pylist(), v2=right['v2'].to_pylist())}"
+            )
+            return left.join(right, keys="id", join_type="inner")
+
+        left_df = self.spark.createDataFrame([(1, 10), (2, 20), (1, 30)], ["id", "v1"])
+        right_df = self.spark.createDataFrame([(1, 100), (2, 200), (1, 300)], ["id", "v2"])
+
+        grouped_left = left_df.groupBy("id")
+        grouped_right = right_df.groupBy("id")
+        cogrouped_df = grouped_left.cogroup(grouped_right)
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                cogrouped_df.applyInArrow(func_with_logging, "id long, v1 long, v2 long"),
+                [Row(id=1, v1=v1, v2=v2) for v1 in [10, 30] for v2 in [100, 300]]
+                + [Row(id=2, v1=20, v2=200)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"arrow cogrouped map: {dict(v1=v1, v2=v2)}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_arrow_cogrouped_map",
+                )
+                for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
+            ],
+        )

 class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod

python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py

Lines changed: 77 additions & 0 deletions
@@ -17,6 +17,7 @@
 import inspect
 import os
 import time
+import logging
 from typing import Iterator, Tuple
 import unittest

@@ -29,6 +30,8 @@
     have_pyarrow,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only

 if have_pyarrow:
     import pyarrow as pa
@@ -394,6 +397,80 @@ def test_negative_and_zero_batch_size(self):
         with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
             ApplyInArrowTestsMixin.test_apply_in_arrow(self)

+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_apply_in_arrow_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(group):
+            assert isinstance(group, pa.Table)
+            logger = logging.getLogger("test_arrow_grouped_map")
+            logger.warning(f"arrow grouped map: {group.to_pydict()}")
+            return group
+
+        df = self.spark.range(9).withColumn("value", col("id") * 10)
+        grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                grouped_df.applyInArrow(func_with_logging, "id long, value long"),
+                df,
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_arrow_grouped_map",
+                )
+                for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
+            ],
+        )
+
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_apply_in_arrow_iter_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(group: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+            logger = logging.getLogger("test_arrow_grouped_map")
+            for batch in group:
+                assert isinstance(batch, pa.RecordBatch)
+                logger.warning(f"arrow grouped map: {batch.to_pydict()}")
+                yield batch
+
+        df = self.spark.range(9).withColumn("value", col("id") * 10)
+        grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": 3,
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                grouped_df.applyInArrow(func_with_logging, "id long, value long"),
+                df,
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_arrow_grouped_map",
+                )
+                for lst in [[0, 2, 4], [6, 8], [1, 3, 5], [7]]
+            ],
+        )

 class ApplyInArrowTests(ApplyInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod

python/pyspark/sql/tests/arrow/test_arrow_map.py

Lines changed: 55 additions & 0 deletions
@@ -17,6 +17,7 @@
 import os
 import time
 import unittest
+import logging

 from pyspark.sql.utils import PythonException
 from pyspark.testing.sqlutils import (
@@ -26,6 +27,9 @@
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.sql import Row
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only

 if have_pyarrow:
     import pyarrow as pa
@@ -221,6 +225,46 @@ def func(iterator):
         df = self.spark.range(1)
         df.mapInArrow(func, "a int").collect()

+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_map_in_arrow_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(iterator):
+            logger = logging.getLogger("test_arrow_map")
+            for batch in iterator:
+                assert isinstance(batch, pa.RecordBatch)
+                logger.warning(f"arrow map: {batch.to_pydict()}")
+                yield batch
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                self.spark.range(9, numPartitions=2).mapInArrow(func_with_logging, "id long"),
+                [Row(id=i) for i in range(9)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            self._expected_logs_for_test_map_in_arrow_with_logging(func_with_logging.__name__),
+        )
+
+    def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
+        return [
+            Row(
+                level="WARNING",
+                msg=f"arrow map: {dict(id=lst)}",
+                context={"func_name": func_name},
+                logger="test_arrow_map",
+            )
+            for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+        ]

 class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
@@ -253,6 +297,17 @@ def setUpClass(cls):
         cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "3")
         cls.spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "10")

+    def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
+        return [
+            Row(
+                level="WARNING",
+                msg=f"arrow map: {dict(id=[i])}",
+                context={"func_name": func_name},
+                logger="test_arrow_map",
+            )
+            for i in range(9)
+        ]

 class MapInArrowWithOutputArrowBatchSlicingRecordsTests(MapInArrowTests):
     @classmethod

python/pyspark/sql/tests/arrow/test_arrow_python_udf.py

Lines changed: 0 additions & 12 deletions
@@ -60,18 +60,6 @@ def test_register_java_function(self):
     def test_register_java_udaf(self):
         super(ArrowPythonUDFTests, self).test_register_java_udaf()

-    @unittest.skip(
-        "TODO(SPARK-53976): Python worker logging is not supported for Arrow Python UDFs."
-    )
-    def test_udf_with_logging(self):
-        super().test_udf_with_logging()
-
-    @unittest.skip(
-        "TODO(SPARK-53976): Python worker logging is not supported for Arrow Python UDFs."
-    )
-    def test_multiple_udfs_with_logging(self):
-        super().test_multiple_udfs_with_logging()
-
     def test_complex_input_types(self):
         row = (
             self.spark.range(1)
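The two skips above are deleted because this change extends worker logging to Arrow-optimized Python UDFs, so the inherited logging tests now run. A hypothetical sketch of that pattern, not from the diff, assuming an active `spark` session; the useArrow flag is the standard PySpark way to opt a scalar UDF into Arrow execution:

import logging

from pyspark.sql.functions import udf

@udf("long", useArrow=True)  # Arrow-optimized Python UDF
def add_one(x):
    # Before this commit, these records were not captured for Arrow Python UDFs.
    logging.getLogger("my_arrow_udf").warning("got %s", x)
    return x + 1

spark.conf.set("spark.sql.pyspark.worker.logging.enabled", "true")
spark.range(3).select(add_one("id")).collect()
spark.table("system.session.python_worker_logs").show(truncate=False)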

python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py

Lines changed: 39 additions & 1 deletion
@@ -16,9 +16,10 @@
 #

 import unittest
+import logging

 from pyspark.sql.functions import arrow_udf, ArrowUDFType
-from pyspark.util import PythonEvalType
+from pyspark.util import PythonEvalType, is_remote_only
 from pyspark.sql import Row
 from pyspark.sql.types import (
     ArrayType,
@@ -35,6 +36,7 @@
     numpy_requirement_message,
     have_pyarrow,
     pyarrow_requirement_message,
+    assertDataFrameEqual,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase

@@ -1021,6 +1023,42 @@ def arrow_max(v):

         self.assertEqual(expected, result)

+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_grouped_agg_arrow_udf_with_logging(self):
+        import pyarrow as pa
+
+        @arrow_udf("double", ArrowUDFType.GROUPED_AGG)
+        def my_grouped_agg_arrow_udf(x):
+            assert isinstance(x, pa.Array)
+            logger = logging.getLogger("test_grouped_agg_arrow")
+            logger.warning(f"grouped agg arrow udf: {len(x)}")
+            return pa.compute.sum(x)
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                df.groupby("id").agg(my_grouped_agg_arrow_udf("v").alias("result")),
+                [Row(id=1, result=3.0), Row(id=2, result=18.0)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"grouped agg arrow udf: {n}",
+                    context={"func_name": my_grouped_agg_arrow_udf.__name__},
+                    logger="test_grouped_agg_arrow",
+                )
+                for n in [2, 3]
+            ],
+        )

 class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
