From 1cdbe7d59b1510dc2b8418fb44da7f65a301c3b8 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 21 Oct 2025 17:22:57 +0900 Subject: [PATCH 1/3] Prototype of runtime profiler --- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../api/python/PythonWorkerFactory.scala | 81 ++++++---- python/pyspark/daemon.py | 33 +++- python/pyspark/worker.py | 23 +++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../PythonProfileMicroBatchStream.scala | 149 ++++++++++++++++++ .../sources/PythonProfileSourceProvider.scala | 67 ++++++++ 7 files changed, 324 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileMicroBatchStream.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileSourceProvider.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 796dbf4b6d5f8..44950a024953b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -94,7 +94,7 @@ class SparkEnv ( */ private case class PythonWorkersKey( pythonExec: String, workerModule: String, daemonModule: String, envVars: Map[String, String]) - private val pythonWorkers = mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]() + private[sql] val pythonWorkers = mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index e02f10cc3fe69..90b3368769704 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -38,36 +38,43 @@ import org.apache.spark.internal.config.Python.PYTHON_FACTORY_IDLE_WORKER_MAX_PO import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} -case class PythonWorker(channel: SocketChannel) { - - private[this] var selectorOpt: Option[Selector] = None - private[this] var selectionKeyOpt: Option[SelectionKey] = None - - def selector: Selector = selectorOpt.orNull - def selectionKey: SelectionKey = selectionKeyOpt.orNull - - private def closeSelector(): Unit = { - selectionKeyOpt.foreach(_.cancel()) - selectorOpt.foreach(_.close()) +case class PythonWorker( + channel: SocketChannel, + extraChannel: Option[SocketChannel] = None) { + + private[this] var selectors: Seq[Selector] = Seq.empty + private[this] var selectionKeys: Seq[SelectionKey] = Seq.empty + + private def closeSelectors(): Unit = { + selectionKeys.foreach(_.cancel()) + selectors.foreach(_.close()) + selectors = Seq.empty + selectionKeys = Seq.empty } def refresh(): this.type = synchronized { - closeSelector() - if (channel.isBlocking) { - selectorOpt = None - selectionKeyOpt = None - } else { - val selector = Selector.open() - selectorOpt = Some(selector) - selectionKeyOpt = - Some(channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE)) - } + closeSelectors() + + val channels = Seq(Some(channel), extraChannel).flatten + val (selList, keyList) = channels.map { ch => + if (ch.isBlocking) { + (None, None) + } else { + val selector = Selector.open() + val key = ch.register(selector, SelectionKey.OP_READ 
| SelectionKey.OP_WRITE) + (Some(selector), Some(key)) + } + }.unzip + + selectors = selList.flatten + selectionKeys = keyList.flatten this } def stop(): Unit = synchronized { - closeSelector() + closeSelectors() Option(channel).foreach(_.close()) + extraChannel.foreach(_.close()) } } @@ -129,6 +136,10 @@ private[spark] class PythonWorkerFactory( envVars.getOrElse("PYTHONPATH", ""), sys.env.getOrElse("PYTHONPATH", "")) + def getAllDaemonWorkers: Seq[(PythonWorker, ProcessHandle)] = self.synchronized { + daemonWorkers.filter { case (_, handle) => handle.isAlive}.toSeq + } + def create(): (PythonWorker, Option[ProcessHandle]) = { if (useDaemon) { self.synchronized { @@ -163,22 +174,36 @@ private[spark] class PythonWorkerFactory( private def createThroughDaemon(): (PythonWorker, Option[ProcessHandle]) = { def createWorker(): (PythonWorker, Option[ProcessHandle]) = { - val socketChannel = if (isUnixDomainSock) { + val mainChannel = if (isUnixDomainSock) { SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath)) } else { SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort)) } + + val extraChannel = if (envVars.getOrElse("PYSPARK_RUNTIME_PROFILE", "false").toBoolean) { + if (isUnixDomainSock) { + Some(SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath))) + } else { + Some(SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))) + } + } else { + None + } + // These calls are blocking. - val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt() + val pid = new DataInputStream(Channels.newInputStream(mainChannel)).readInt() if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } val processHandle = ProcessHandle.of(pid).orElseThrow( () => new IllegalStateException("Python daemon failed to launch worker.") ) - authHelper.authToServer(socketChannel) - socketChannel.configureBlocking(false) - val worker = PythonWorker(socketChannel) + + authHelper.authToServer(mainChannel) + mainChannel.configureBlocking(false) + extraChannel.foreach(_.configureBlocking(false)) + + val worker = PythonWorker(mainChannel, extraChannel) daemonWorkers.put(worker, processHandle) (worker.refresh(), Some(processHandle)) } @@ -271,7 +296,7 @@ private[spark] class PythonWorkerFactory( if (!blockingMode) { socketChannel.configureBlocking(false) } - val worker = PythonWorker(socketChannel) + val worker = PythonWorker(socketChannel, None) self.synchronized { simpleWorkers.put(worker, workerProcess) } diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index ca33ce2c39ef7..ce0899437d852 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -47,7 +47,7 @@ def compute_real_exit_code(exit_code): return 1 -def worker(sock, authenticated): +def worker(sock, sock2, authenticated): """ Called by a worker process after the fork(). 
""" @@ -64,6 +64,9 @@ def worker(sock, authenticated): buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536)) infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size) outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size) + outfile2 = None + if sock2 is not None: + outfile2 = os.fdopen(os.dup(sock2.fileno()), "wb", buffer_size) if not authenticated: client_secret = UTF8Deserializer().loads(infile) @@ -74,11 +77,16 @@ def worker(sock, authenticated): write_with_length("err".encode("utf-8"), outfile) outfile.flush() sock.close() + if sock2 is not None: + sock2.close() return 1 exit_code = 0 try: - worker_main(infile, outfile) + if sock2 is not None: + worker_main(infile, (outfile, outfile2)) + else: + worker_main(infile, outfile) except SystemExit as exc: exit_code = compute_real_exit_code(exc.code) finally: @@ -94,6 +102,7 @@ def manager(): os.setpgid(0, 0) is_unix_domain_sock = os.environ.get("PYTHON_UNIX_DOMAIN_ENABLED", "false").lower() == "true" + is_python_runtime_profile = os.environ.get("PYSPARK_RUNTIME_PROFILE", "false").lower() == "true" socket_path = None # Create a listening socket on the loopback interface @@ -173,6 +182,15 @@ def handle_sigterm(*args): continue raise + sock2 = None + if is_python_runtime_profile: + try: + sock2, _ = listen_sock.accept() + except OSError as e: + if e.errno == EINTR: + continue + raise + # Launch a worker process try: pid = os.fork() @@ -186,6 +204,13 @@ def handle_sigterm(*args): outfile.flush() outfile.close() sock.close() + + if sock2 is not None: + outfile = sock2.makefile(mode="wb") + write_int(e.errno, outfile) # Signal that the fork failed + outfile.flush() + outfile.close() + sock2.close() continue if pid == 0: @@ -217,7 +242,7 @@ def handle_sigterm(*args): or False ) while True: - code = worker(sock, authenticated) + code = worker(sock, sock2, authenticated) if code == 0: authenticated = True if not reuse or code: @@ -225,6 +250,8 @@ def handle_sigterm(*args): try: while sock.recv(1024): pass + while sock2 is not None and sock2.recv(1024): + pass except Exception: pass break diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8ab32b4312bb7..286bcf914235e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -18,6 +18,8 @@ """ Worker that receives input from Piped RDD. 
""" +import pickle +import threading import itertools import os import sys @@ -45,6 +47,7 @@ read_bool, write_long, read_int, + write_with_length, SpecialLengths, CPickleSerializer, BatchedSerializer, @@ -3167,7 +3170,27 @@ def func(_, it): return func, None, ser, ser +def write_profile(outfile): + import yappi + + while True: + stats = [] + for thread in yappi.get_thread_stats(): + data = list(yappi.get_func_stats(ctx_id=thread.id)) + stats.extend([{str(k): str(v) for k, v in d.items()} for d in data]) + pickled = pickle.dumps(stats) + write_with_length(pickled, outfile) + time.sleep(1) + + def main(infile, outfile): + if isinstance(outfile, tuple): + import yappi + + outfile, outfile2 = outfile + yappi.start() + threading.Thread(target=write_profile, args=(outfile2,), daemon=True).start() + faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None) tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None) try: diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index b628c753a7676..1e40a1c95d0b5 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -26,6 +26,7 @@ org.apache.spark.sql.execution.datasources.xml.XmlFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +org.apache.spark.sql.execution.streaming.sources.PythonProfileSourceProvider org.apache.spark.sql.execution.datasources.binaryfile.BinaryFileFormat org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileMicroBatchStream.scala new file mode 100644 index 0000000000000..607d536e78390 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileMicroBatchStream.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.DataInputStream +import java.nio.channels.Channels +import java.util.concurrent.atomic.AtomicBoolean +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ + +import net.razorvine.pickle.Unpickler + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} +import org.apache.spark.sql.execution.streaming.runtime.LongOffset +import org.apache.spark.unsafe.types.UTF8String + + +class PythonProfileMicroBatchStream + extends MicroBatchStream with Logging { + + @GuardedBy("this") + private var readThread: Thread = null + + @GuardedBy("this") + private val batches = new ListBuffer[java.util.List[java.util.Map[String, String]]] + + @GuardedBy("this") + private var currentOffset: LongOffset = LongOffset(-1L) + + @GuardedBy("this") + private var lastOffsetCommitted: LongOffset = LongOffset(-1L) + + private val initialized: AtomicBoolean = new AtomicBoolean(false) + + private def initialize(): Unit = synchronized { + readThread = new Thread(s"PythonProfileMicroBatchStream") { + setDaemon(true) + + override def run(): Unit = { + val unpickler = new Unpickler + val extraChannel = SparkEnv.get.pythonWorkers.values + .head.getAllDaemonWorkers.map(_._1.extraChannel).head + extraChannel.foreach { s => + val inputStream = new DataInputStream(Channels.newInputStream(s)) + while (true) { + val len = inputStream.readInt() + val buf = new Array[Byte](len) + var totalRead = 0 + while (totalRead < len) { + val readNow = inputStream.read(buf, totalRead, len - totalRead) + assert(readNow != -1) + totalRead += readNow + } + currentOffset += 1 + batches.append( + unpickler.loads(buf).asInstanceOf[java.util.List[java.util.Map[String, String]]]) + } + } + } + } + readThread.start() + } + + override def initialOffset(): Offset = LongOffset(-1L) + + override def latestOffset(): Offset = currentOffset + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { + val startOrdinal = start.asInstanceOf[LongOffset].offset.toInt + 1 + val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 + + val rawList = synchronized { + if (initialized.compareAndSet(false, true)) { + initialize() + } + + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + batches.slice(sliceStart, sliceEnd) + } + + Array(PythonProfileInputPartition(rawList)) + } + + override def createReaderFactory(): PartitionReaderFactory = + (partition: InputPartition) => { + val stats = partition.asInstanceOf[PythonProfileInputPartition].stats + new PartitionReader[InternalRow] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < stats.size + } + + override def get(): InternalRow = { + InternalRow.fromSeq(stats(currentIdx).asScala.toSeq.map(_.asScala)) + } + + override def close(): Unit = {} + } + } + + override def commit(end: Offset): Unit = synchronized { + val newOffset = end.asInstanceOf[LongOffset] + + val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt + + if (offsetDiff < 0) { 
+ throw new IllegalStateException( + s"Offsets committed out of order: $lastOffsetCommitted followed by $end") + } + + batches.dropInPlace(offsetDiff) + lastOffsetCommitted = newOffset + } + + override def toString: String = s"PythonProfile" + + override def stop(): Unit = { } +} + +case class PythonProfileInputPartition( + stats: ListBuffer[java.util.List[java.util.Map[String, String]]]) extends InputPartition diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileSourceProvider.scala new file mode 100644 index 0000000000000..c3dba019f30d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileSourceProvider.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} +import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} +import org.apache.spark.sql.internal.connector.SimpleTableProvider +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class PythonProfileSourceProvider extends SimpleTableProvider with DataSourceRegister with Logging { + override def getTable(options: CaseInsensitiveStringMap): Table = new PythonProfileTable() + + override def shortName(): String = "socket" +} + +class PythonProfileTable() + extends Table with SupportsRead { + + override def name(): String = s"PythonProfile" + + override def schema(): StructType = { + StructType(StructField("key", MapType(keyType = StringType, valueType = StringType)) :: Nil) + } + + override def capabilities(): util.Set[TableCapability] = { + util.EnumSet.of(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ) + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan { + override def readSchema(): StructType = { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + columns.asSchema + } + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { + new PythonProfileMicroBatchStream() + } + + override def toContinuousStream(checkpointLocation: String): ContinuousStream = { + throw new UnsupportedOperationException() + } + + override def columnarSupportMode(): Scan.ColumnarSupportMode = + Scan.ColumnarSupportMode.UNSUPPORTED + } +} 
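
Note on the wire format patch 1 establishes: write_profile() in worker.py pickles a list of per-function stat dicts and writes each snapshot onto the extra daemon channel via write_with_length (a 4-byte big-endian length followed by the pickled payload, roughly once per second), and PythonProfileMicroBatchStream reads the same framing with a DataInputStream before handing the bytes to an Unpickler. The following is a minimal, self-contained Python sketch of that framing only; the helper names below are illustrative and are not part of the patch.

import pickle
import struct
import time


def stream_profile_snapshots(out, snapshots, interval=1.0):
    """Write length-prefixed pickled snapshots, mirroring write_with_length()."""
    for stats in snapshots:  # stats: list of {str: str} dicts, as produced by write_profile()
        payload = pickle.dumps(stats)
        out.write(struct.pack("!i", len(payload)))  # 4-byte big-endian length prefix
        out.write(payload)
        out.flush()
        time.sleep(interval)  # the prototype emits roughly one snapshot per second


def read_profile_snapshot(inp):
    """Read one snapshot back; the JVM side does the equivalent with DataInputStream + Unpickler."""
    header = inp.read(4)
    if len(header) < 4:
        raise EOFError("profile channel closed")
    (length,) = struct.unpack("!i", header)
    data = inp.read(length)
    if len(data) < length:
        raise EOFError("truncated profile payload")
    return pickle.loads(data)
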
From 39fb9dcb4a8a2c9f914b5a54f2005b2175b39e1c Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 22 Oct 2025 16:59:05 +0900 Subject: [PATCH 2/3] working version --- .../scala/org/apache/spark/SparkEnv.scala | 5 +- .../spark/api/python/PythonRunner.scala | 1 + .../api/python/PythonWorkerFactory.scala | 46 +++++++++---------- python/pyspark/worker.py | 1 + .../PythonProfileMicroBatchStream.scala | 7 +-- .../sources/PythonProfileSourceProvider.scala | 5 +- .../sql/execution/python/PythonUDFSuite.scala | 15 +++++- 7 files changed, 47 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 44950a024953b..df60940f86c09 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -92,9 +92,10 @@ class SparkEnv ( * @param daemonModule The daemon module name to reuse the worker, e.g., "pyspark.daemon". * @param envVars The environment variables for the worker. */ - private case class PythonWorkersKey( + case class PythonWorkersKey( pythonExec: String, workerModule: String, daemonModule: String, envVars: Map[String, String]) - private[sql] val pythonWorkers = mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]() + val pythonWorkers: mutable.Map[PythonWorkersKey, PythonWorkerFactory] = + mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 8e5b7ef001b84..5b0f9a4970fe0 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -281,6 +281,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } // allow the user to set the batch size for the BatchedSerializer on UDFs envVars.put("PYTHON_UDF_BATCH_SIZE", batchSizeForPythonUDF.toString) + envVars.put("PYSPARK_RUNTIME_PROFILE", true.toString) envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 90b3368769704..ea1aa7f987134 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -42,37 +42,33 @@ case class PythonWorker( channel: SocketChannel, extraChannel: Option[SocketChannel] = None) { - private[this] var selectors: Seq[Selector] = Seq.empty - private[this] var selectionKeys: Seq[SelectionKey] = Seq.empty - - private def closeSelectors(): Unit = { - selectionKeys.foreach(_.cancel()) - selectors.foreach(_.close()) - selectors = Seq.empty - selectionKeys = Seq.empty - } + private[this] var selectorOpt: Option[Selector] = None + private[this] var selectionKeyOpt: Option[SelectionKey] = None - def refresh(): this.type = synchronized { - closeSelectors() + def selector: Selector = selectorOpt.orNull + def selectionKey: SelectionKey = selectionKeyOpt.orNull - val channels = Seq(Some(channel), extraChannel).flatten - val (selList, keyList) = channels.map { ch => - if (ch.isBlocking) { - (None, None) - } else { - val selector = Selector.open() - val key = ch.register(selector, 
SelectionKey.OP_READ | SelectionKey.OP_WRITE) - (Some(selector), Some(key)) - } - }.unzip + private def closeSelector(): Unit = { + selectionKeyOpt.foreach(_.cancel()) + selectorOpt.foreach(_.close()) + } - selectors = selList.flatten - selectionKeys = keyList.flatten + def refresh(): this.type = synchronized { + closeSelector() + if (channel.isBlocking) { + selectorOpt = None + selectionKeyOpt = None + } else { + val selector = Selector.open() + selectorOpt = Some(selector) + selectionKeyOpt = + Some(channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE)) + } this } def stop(): Unit = synchronized { - closeSelectors() + closeSelector() Option(channel).foreach(_.close()) extraChannel.foreach(_.close()) } @@ -201,7 +197,7 @@ private[spark] class PythonWorkerFactory( authHelper.authToServer(mainChannel) mainChannel.configureBlocking(false) - extraChannel.foreach(_.configureBlocking(false)) + extraChannel.foreach(_.configureBlocking(true)) val worker = PythonWorker(mainChannel, extraChannel) daemonWorkers.put(worker, processHandle) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 286bcf914235e..0cdc536459202 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -3180,6 +3180,7 @@ def write_profile(outfile): stats.extend([{str(k): str(v) for k, v in d.items()} for d in data]) pickled = pickle.dumps(stats) write_with_length(pickled, outfile) + outfile.flush() time.sleep(1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileMicroBatchStream.scala index 607d536e78390..4771e6cbbede0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileMicroBatchStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileMicroBatchStream.scala @@ -29,11 +29,10 @@ import net.razorvine.pickle.Unpickler import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} import org.apache.spark.sql.execution.streaming.runtime.LongOffset -import org.apache.spark.unsafe.types.UTF8String class PythonProfileMicroBatchStream @@ -119,7 +118,9 @@ class PythonProfileMicroBatchStream } override def get(): InternalRow = { - InternalRow.fromSeq(stats(currentIdx).asScala.toSeq.map(_.asScala)) + InternalRow.fromSeq( + CatalystTypeConverters.convertToCatalyst( + stats(currentIdx).asScala.toSeq.map(_.asScala)) :: Nil) } override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileSourceProvider.scala index c3dba019f30d0..21cc14d2e9170 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PythonProfileSourceProvider.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import 
org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, MapType, StringType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap class PythonProfileSourceProvider extends SimpleTableProvider with DataSourceRegister with Logging { @@ -40,7 +40,8 @@ class PythonProfileTable() override def name(): String = s"PythonProfile" override def schema(): StructType = { - StructType(StructField("key", MapType(keyType = StringType, valueType = StringType)) :: Nil) + StructType(StructField( + "data", ArrayType(MapType(keyType = StringType, valueType = StringType))) :: Nil) } override def capabilities(): util.Set[TableCapability] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 9b40226c2049b..2d9a060e89005 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, Dataset, IntegratedUDFTestUtils, QueryTest, Row} import org.apache.spark.sql.functions.{array, avg, col, count, transform} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.LongType @@ -156,4 +156,17 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Row(0, 0.0, 0)) } + + test("profile test") { + assume(shouldTestPythonUDFs) + val df2 = base.groupBy(pythonTestUDF(base("a") + 1)) + .agg(pythonTestUDF(base("a") + 1), pythonTestUDF(count(base("b")))) + df2.collect() + val df3 = spark.readStream.format( + "org.apache.spark.sql.execution.streaming.sources.PythonProfileSourceProvider").load() + val q = df3.writeStream.foreachBatch((df: Dataset[Row], batchId: Long) => { + df.show(truncate = false) + }).start() + q.awaitTermination() + } } From b2d1e01ef73b346073b0a5247b1e1dd0c191451b Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 24 Oct 2025 12:25:15 +0900 Subject: [PATCH 3/3] fixup --- .../sql/execution/python/PythonUDFSuite.scala | 60 ++++++++++++++++++- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 2d9a060e89005..ed15c661f3a67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -162,11 +162,65 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { val df2 = base.groupBy(pythonTestUDF(base("a") + 1)) .agg(pythonTestUDF(base("a") + 1), pythonTestUDF(count(base("b")))) df2.collect() + // scalastyle:off println val df3 = spark.readStream.format( "org.apache.spark.sql.execution.streaming.sources.PythonProfileSourceProvider").load() - val q = df3.writeStream.foreachBatch((df: Dataset[Row], batchId: Long) => { - df.show(truncate = false) - }).start() + val q = df3.writeStream.foreachBatch { (df: Dataset[Row], batchId: Long) => + // Clear the console before printing the 
next batch + print("\u001b[2J") // Clear entire screen + print("\u001b[H") // Move cursor to top-left corner + + df.collect().foreach { row => + val listOfMaps = row.get(0).asInstanceOf[collection.mutable.ArraySeq[Map[String, String]]] + + // Group by thread id and name + val groupedByThread = listOfMaps.groupBy(m => + (m.getOrElse("10", "unknown"), m.getOrElse("11", "unknown")) + ) + + groupedByThread.toSeq.sortBy(_._1._1).foreach { case ((threadId, threadName), funcsAll) => + println(s"\nFunction stats for (Thread) ($threadId - $threadName)\n") + + // Only top 5 rows for readability + val funcs = funcsAll.take(5) + + // Dynamic column widths + val nameWidth = + (funcs.map(_.getOrElse("15", "").length).maxOption.getOrElse(4) max "name".length) + 2 + val ncallWidth = + (funcs.map(_.getOrElse("3", "").length).maxOption.getOrElse(5) max "ncall".length) + 2 + val timeWidth = 12 + + val fmt = + s"%-${nameWidth}s %${ncallWidth}s %${timeWidth}s %${timeWidth}s %${timeWidth}s" + + println(fmt.format("name", "ncall", "tsub", "ttot", "tavg")) + println("-" * (nameWidth + ncallWidth + timeWidth * 3 + 4)) + + funcs.foreach { m => + val name = { + val full = m.getOrElse("15", "") + if (full.length > nameWidth) "..." + full.takeRight(nameWidth - 3) else full + } + + val ncall = m.getOrElse("3", "") + val tsub = formatDouble(m.getOrElse("7", "")) + val ttot = formatDouble(m.getOrElse("6", "")) + val tavg = formatDouble(m.getOrElse("14", "")) + + println(fmt.format(name, ncall, tsub, ttot, tavg)) + } + } + } + + // Helper to format numbers as fixed-point + def formatDouble(value: String): String = { + try f"${value.toDouble}%.6f" catch { + case _: Throwable => value + } + } + }.start() q.awaitTermination() + // scalastyle:on println } }
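
The test above drives the source from Scala with foreachBatch. For completeness, a hypothetical PySpark consumer of the same stream, assuming the fully qualified provider class name from patch 1 and using the built-in console sink instead of foreachBatch (not part of the patch), could look like this:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# The source exposes a single column: data ARRAY<MAP<STRING, STRING>> (schema from patch 2).
profile_df = (
    spark.readStream
    .format("org.apache.spark.sql.execution.streaming.sources.PythonProfileSourceProvider")
    .load()
)

query = (
    profile_df.writeStream
    .format("console")
    .option("truncate", "false")
    .start()
)
query.awaitTermination()
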