51 changes: 51 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
import unittest
import logging
from typing import Iterator, Optional

from pyspark.errors import PySparkAttributeError
@@ -23,6 +24,7 @@
from pyspark.sql.types import Row, StructType, StructField, IntegerType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pyarrow, pyarrow_requirement_message
from pyspark.testing import assertDataFrameEqual
from pyspark.util import is_remote_only

if have_pyarrow:
import pyarrow as pa
@@ -1685,6 +1687,55 @@ def eval(
)
assertDataFrameEqual(sql_result_df2, expected_df2)

@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_arrow_udtf_with_logging(self):
import pyarrow as pa

@arrow_udtf(returnType="id bigint, doubled bigint")
class TestArrowUDTFWithLogging:
def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
assert isinstance(
table_data, pa.RecordBatch
), f"Expected pa.RecordBatch, got {type(table_data)}"

logger = logging.getLogger("test_arrow_udtf")
logger.warning(f"arrow udtf: {table_data.to_pydict()}")

# Convert record batch to table
table = pa.table(table_data)

# Get the id column and create doubled values
id_column = table.column("id")
doubled_values = pa.compute.multiply(id_column, pa.scalar(2))

yield pa.table({"id": id_column, "doubled": doubled_values})

with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": "3",
"spark.sql.pyspark.worker.logging.enabled": "true",
}
):
assertDataFrameEqual(
TestArrowUDTFWithLogging(self.spark.range(9, numPartitions=2).asTable()),
[Row(id=i, doubled=i * 2) for i in range(9)],
)

logs = self.spark.table("system.session.python_worker_logs")

assertDataFrameEqual(
logs.select("level", "msg", "context", "logger"),
[
Row(
level="WARNING",
msg=f"arrow udtf: {dict(id=lst)}",
context={"class_name": "TestArrowUDTFWithLogging", "func_name": "eval"},
logger="test_arrow_udtf",
)
for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
],
)


class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
pass
34 changes: 34 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
@@ -22,6 +22,7 @@
import shutil
import tempfile
import unittest
import logging
import time
from dataclasses import dataclass
from typing import Iterator, Optional
@@ -73,6 +74,7 @@
pyarrow_requirement_message,
ReusedSQLTestCase,
)
from pyspark.util import is_remote_only


class BaseUDTFTestsMixin:
@@ -3059,6 +3061,38 @@ def eval(self, b):
result = BinaryTypeUDTF(lit(b"test")).collect()
self.assertEqual(result[0]["type_name"], expected_type)

@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_udtf_with_logging(self):
@udtf(returnType="a: int, b: int")
class TestUDTFWithLogging:
def eval(self, x: int):
logger = logging.getLogger("test_udtf")
logger.warning(f"udtf with logging: {x}")
yield x * 2, x + 10

with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
assertDataFrameEqual(
self.spark.createDataFrame([(5,), (10,)], ["x"]).lateralJoin(
TestUDTFWithLogging(col("x").outer())
),
[Row(x=x, a=x * 2, b=x + 10) for x in [5, 10]],
)

logs = self.spark.table("system.session.python_worker_logs")

assertDataFrameEqual(
logs.select("level", "msg", "context", "logger"),
[
Row(
level="WARNING",
msg=f"udtf with logging: {x}",
context={"class_name": "TestUDTFWithLogging", "func_name": "eval"},
logger="test_udtf",
)
for x in [5, 10]
],
)


class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
@classmethod
@@ -53,6 +53,12 @@ case class ArrowEvalPythonUDTFExec(
private val largeVarTypes = conf.arrowUseLargeVarTypes
private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
private[this] val sessionUUID = {
Option(session).collect {
case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
session.sessionUUID
}
}

override protected def evaluate(
argMetas: Array[ArgumentMetadata],
@@ -75,7 +81,8 @@
largeVarTypes,
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID).compute(batchIter, context.partitionId(), context)
jobArtifactUUID,
sessionUUID).compute(batchIter, context.partitionId(), context)

columnarBatchIter.map { batch =>
// UDTF returns a StructType column in ColumnarBatch. Flatten the columnar batch here.
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python

import java.io.DataOutputStream
import java.util

import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
@@ -40,7 +41,8 @@ class ArrowPythonUDTFRunner(
protected override val largeVarTypes: Boolean,
protected override val workerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, Array(argMetas.map(_.offset)),
jobArtifactUUID, pythonMetrics)
@@ -65,6 +67,13 @@ class ArrowPythonUDTFRunner(
PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
}

override val envVars: util.Map[String, String] = {
val envVars = new util.HashMap(funcs.head.funcs.head.envVars)
sessionUUID.foreach { uuid =>
envVars.put("PYSPARK_SPARK_SESSION_UUID", uuid)
}
envVars
}
override val pythonExec: String =
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
funcs.head.funcs.head.pythonExec)
@@ -51,6 +51,12 @@ case class BatchEvalPythonUDTFExec(
extends EvalPythonUDTFExec with PythonSQLMetrics {

private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
private[this] val sessionUUID = {
Option(session).collect {
case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
session.sessionUUID
}
}

/**
* Evaluates a Python UDTF. It computes the results using the PythonUDFRunner, and returns
@@ -70,7 +76,7 @@

// Output iterator for results from Python.
val outputIterator =
new PythonUDTFRunner(udtf, argMetas, pythonMetrics, jobArtifactUUID)
new PythonUDTFRunner(udtf, argMetas, pythonMetrics, jobArtifactUUID, sessionUUID)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
@@ -99,10 +105,12 @@ class PythonUDTFRunner(
udtf: PythonUDTF,
argMetas: Array[ArgumentMetadata],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
extends BasePythonUDFRunner(
Seq((ChainedPythonFunctions(Seq(udtf.func)), udtf.resultId.id)),
PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)), pythonMetrics, jobArtifactUUID) {
PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)), pythonMetrics,
jobArtifactUUID, sessionUUID) {

// Overriding here to NOT use the same value of UDF config in UDTF.
override val bufferSize: Int = SparkEnv.get.conf.get(BUFFER_SIZE)
@@ -36,7 +36,7 @@ abstract class BasePythonUDFRunner(
argOffsets: Array[Array[Int]],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String] = None)
sessionUUID: Option[String])
extends BasePythonRunner[Array[Byte], Array[Byte]](
funcs.map(_._1), evalType, argOffsets, jobArtifactUUID, pythonMetrics) {
