[DO-NOT-MERGE] Prototype: runtime profiling of Python workers #52679
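The diff below gates the extra profile channel on a `PYSPARK_RUNTIME_PROFILE` entry in the Python worker's envVars. As a rough usage sketch (not part of this PR), assuming the flag reaches the worker environment via `spark.executorEnv.*` (that propagation path is an assumption for illustration):

```python
# Hypothetical way to switch on the prototype's extra profile channel.
# PYSPARK_RUNTIME_PROFILE is the flag the JVM side checks in the worker's envVars;
# forwarding it via spark.executorEnv.* is an assumption, not shown in this diff.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("python-worker-runtime-profiling")
    .config("spark.executorEnv.PYSPARK_RUNTIME_PROFILE", "true")
    .getOrCreate()
)
```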
PythonWorkerFactory.scala

@@ -38,7 +38,9 @@ import org.apache.spark.internal.config.Python.PYTHON_FACTORY_IDLE_WORKER_MAX_PO
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util.{RedirectThread, Utils}

-case class PythonWorker(channel: SocketChannel) {
+case class PythonWorker(
+    channel: SocketChannel,
+    extraChannel: Option[SocketChannel] = None) {
Review comment: Let's call this profile data channel or something?

   private[this] var selectorOpt: Option[Selector] = None
   private[this] var selectionKeyOpt: Option[SelectionKey] = None
@@ -68,6 +70,7 @@ case class PythonWorker(channel: SocketChannel) {
   def stop(): Unit = synchronized {
     closeSelector()
     Option(channel).foreach(_.close())
+    extraChannel.foreach(_.close())
   }
 }
@@ -129,6 +132,10 @@ private[spark] class PythonWorkerFactory(
     envVars.getOrElse("PYTHONPATH", ""),
     sys.env.getOrElse("PYTHONPATH", ""))

+  def getAllDaemonWorkers: Seq[(PythonWorker, ProcessHandle)] = self.synchronized {
+    daemonWorkers.filter { case (_, handle) => handle.isAlive }.toSeq
+  }
+
   def create(): (PythonWorker, Option[ProcessHandle]) = {
     if (useDaemon) {
       self.synchronized {
@@ -163,22 +170,36 @@ private[spark] class PythonWorkerFactory(
   private def createThroughDaemon(): (PythonWorker, Option[ProcessHandle]) = {

     def createWorker(): (PythonWorker, Option[ProcessHandle]) = {
-      val socketChannel = if (isUnixDomainSock) {
+      val mainChannel = if (isUnixDomainSock) {
         SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath))
       } else {
         SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
       }

+      val extraChannel = if (envVars.getOrElse("PYSPARK_RUNTIME_PROFILE", "false").toBoolean) {
+        if (isUnixDomainSock) {
+          Some(SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath)))
+        } else {
+          Some(SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort)))
+        }
+      } else {
+        None
+      }
+
       // These calls are blocking.
-      val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
+      val pid = new DataInputStream(Channels.newInputStream(mainChannel)).readInt()
       if (pid < 0) {
         throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
       }
       val processHandle = ProcessHandle.of(pid).orElseThrow(
         () => new IllegalStateException("Python daemon failed to launch worker.")
       )
-      authHelper.authToServer(socketChannel)
-      socketChannel.configureBlocking(false)
-      val worker = PythonWorker(socketChannel)
+
+      authHelper.authToServer(mainChannel)
+      mainChannel.configureBlocking(false)
+      extraChannel.foreach(_.configureBlocking(true))
+
+      val worker = PythonWorker(mainChannel, extraChannel)
       daemonWorkers.put(worker, processHandle)
       (worker.refresh(), Some(processHandle))
     }
@@ -271,7 +292,7 @@ private[spark] class PythonWorkerFactory(
       if (!blockingMode) {
         socketChannel.configureBlocking(false)
       }
-      val worker = PythonWorker(socketChannel)
+      val worker = PythonWorker(socketChannel, None)
       self.synchronized {
         simpleWorkers.put(worker, workerProcess)
       }
python/pyspark/worker.py

@@ -18,6 +18,8 @@
 """
 Worker that receives input from Piped RDD.
 """
+import pickle
+import threading
 import itertools
 import os
 import sys

@@ -45,6 +47,7 @@
     read_bool,
     write_long,
     read_int,
+    write_with_length,
     SpecialLengths,
     CPickleSerializer,
     BatchedSerializer,

@@ -3167,7 +3170,28 @@ def func(_, it):
     return func, None, ser, ser


+def write_profile(outfile):
+    import yappi
+
+    while True:
+        stats = []
+        for thread in yappi.get_thread_stats():
+            data = list(yappi.get_func_stats(ctx_id=thread.id))
+            stats.extend([{str(k): str(v) for k, v in d.items()} for d in data])
+        pickled = pickle.dumps(stats)
Review comment: Why pickle? Would JSON maybe make more sense, so we can interpret it more easily in, say, the Spark UI in the future?
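As a rough sketch of the JSON alternative raised above (not part of this PR): it reuses the same yappi collection loop, swapping the pickle blob for a length-prefixed UTF-8 JSON frame. `write_profile_json` is a hypothetical name, and the frame format is an assumption.

```python
import json
import time

from pyspark.serializers import write_with_length


def write_profile_json(outfile):  # hypothetical alternative to write_profile above
    import yappi

    while True:
        stats = []
        for thread in yappi.get_thread_stats():
            data = list(yappi.get_func_stats(ctx_id=thread.id))
            # Same stringified dict layout as the pickle-based version above.
            stats.extend([{str(k): str(v) for k, v in d.items()} for d in data])
        # Length-prefixed UTF-8 JSON instead of a pickle blob, so a JVM-side
        # consumer (e.g. the Spark UI someday) could parse it without unpickling.
        write_with_length(json.dumps(stats).encode("utf-8"), outfile)
        outfile.flush()
        time.sleep(1)
```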
+        write_with_length(pickled, outfile)
+        outfile.flush()
+        time.sleep(1)
Review comment: I know yappi says it's fast, but a 1-second busy loop seems like it may be overkill, or should at least be configurable?
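A rough sketch of one way to make the interval configurable (not part of this PR); `PYSPARK_RUNTIME_PROFILE_INTERVAL` is a hypothetical variable name chosen only to mirror the existing `PYSPARK_RUNTIME_PROFILE` gate.

```python
import os


def profile_flush_interval(default_seconds=1.0):
    """Poll interval for the profile writer thread, overridable via an env var.

    PYSPARK_RUNTIME_PROFILE_INTERVAL is a hypothetical knob, not part of this PR.
    """
    try:
        return float(os.environ.get("PYSPARK_RUNTIME_PROFILE_INTERVAL", default_seconds))
    except ValueError:
        return default_seconds


# The write_profile loop could then call time.sleep(profile_flush_interval())
# instead of sleeping a hard-coded 1 second.
```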
+
+
 def main(infile, outfile):
Review comment: Maybe rename to outputs.
+    if isinstance(outfile, tuple):
+        import yappi
+
+        outfile, outfile2 = outfile
+        yappi.start()
+        threading.Thread(target=write_profile, args=(outfile2,), daemon=True).start()
+
     faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
     try:
PythonProfileMicroBatchStream.scala (new file)

@@ -0,0 +1,150 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming.sources

import java.io.DataInputStream
import java.nio.channels.Channels
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

import net.razorvine.pickle.Unpickler

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
import org.apache.spark.sql.execution.streaming.runtime.LongOffset

class PythonProfileMicroBatchStream
  extends MicroBatchStream with Logging {

  @GuardedBy("this")
  private var readThread: Thread = null

  @GuardedBy("this")
  private val batches = new ListBuffer[java.util.List[java.util.Map[String, String]]]

  @GuardedBy("this")
  private var currentOffset: LongOffset = LongOffset(-1L)

  @GuardedBy("this")
  private var lastOffsetCommitted: LongOffset = LongOffset(-1L)

  private val initialized: AtomicBoolean = new AtomicBoolean(false)

  private def initialize(): Unit = synchronized {
    readThread = new Thread(s"PythonProfileMicroBatchStream") {
      setDaemon(true)

      override def run(): Unit = {
        val unpickler = new Unpickler
        val extraChannel = SparkEnv.get.pythonWorkers.values
          .head.getAllDaemonWorkers.map(_._1.extraChannel).head
        extraChannel.foreach { s =>
          val inputStream = new DataInputStream(Channels.newInputStream(s))
          while (true) {
            val len = inputStream.readInt()
            val buf = new Array[Byte](len)
            var totalRead = 0
            while (totalRead < len) {
              val readNow = inputStream.read(buf, totalRead, len - totalRead)
              assert(readNow != -1)
              totalRead += readNow
            }
            currentOffset += 1
            batches.append(
              unpickler.loads(buf).asInstanceOf[java.util.List[java.util.Map[String, String]]])
          }
        }
      }
    }
    readThread.start()
  }

  override def initialOffset(): Offset = LongOffset(-1L)

  override def latestOffset(): Offset = currentOffset

  override def deserializeOffset(json: String): Offset = {
    LongOffset(json.toLong)
  }

  override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = {
    val startOrdinal = start.asInstanceOf[LongOffset].offset.toInt + 1
    val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1

    val rawList = synchronized {
      if (initialized.compareAndSet(false, true)) {
        initialize()
      }

      val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
      val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
      batches.slice(sliceStart, sliceEnd)
    }

    Array(PythonProfileInputPartition(rawList))
  }

  override def createReaderFactory(): PartitionReaderFactory =
    (partition: InputPartition) => {
      val stats = partition.asInstanceOf[PythonProfileInputPartition].stats
      new PartitionReader[InternalRow] {
        private var currentIdx = -1

        override def next(): Boolean = {
          currentIdx += 1
          currentIdx < stats.size
        }

        override def get(): InternalRow = {
          InternalRow.fromSeq(
            CatalystTypeConverters.convertToCatalyst(
              stats(currentIdx).asScala.toSeq.map(_.asScala)) :: Nil)
        }

        override def close(): Unit = {}
      }
    }

  override def commit(end: Offset): Unit = synchronized {
    val newOffset = end.asInstanceOf[LongOffset]

    val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt

    if (offsetDiff < 0) {
      throw new IllegalStateException(
        s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
    }

    batches.dropInPlace(offsetDiff)
    lastOffsetCommitted = newOffset
  }

  override def toString: String = s"PythonProfile"

  override def stop(): Unit = { }
}

case class PythonProfileInputPartition(
  stats: ListBuffer[java.util.List[java.util.Map[String, String]]]) extends InputPartition
Review comment: Obviously make configurable later.