11#include " xprof/convert/streaming_trace_viewer_processor.h"
22
3+ #include < algorithm>
34#include < cmath>
45#include < cstdint>
56#include < memory>
1819#include " xla/tsl/platform/file_system.h"
1920#include " xla/tsl/platform/statusor.h"
2021#include " xla/tsl/profiler/utils/timespan.h"
22+ #include " tsl/platform/cpu_info.h"
2123#include " tsl/platform/path.h"
2224#include " tsl/profiler/protobuf/xplane.pb.h"
2325#include " xprof/convert/preprocess_single_host_xplane.h"
3032#include " xprof/convert/trace_viewer/trace_options.h"
3133#include " xprof/convert/trace_viewer/trace_viewer_visibility.h"
3234#include " xprof/convert/xplane_to_trace_container.h"
35+ #include " xprof/convert/xprof_thread_pool_executor.h"
3336
3437namespace xprof {
3538
39+ using internal::GetTraceViewOption;
40+ using internal::TraceViewOption;
3641using ::tensorflow::profiler::IOBufferAdapter;
3742using ::tensorflow::profiler::JsonTraceOptions;
3843using ::tensorflow::profiler::RawData;
@@ -43,9 +48,8 @@ using ::tensorflow::profiler::TraceEventsContainer;
4348using ::tensorflow::profiler::TraceEventsLevelDbFilePaths;
4449using ::tensorflow::profiler::TraceOptionsFromToolOptions;
4550using ::tensorflow::profiler::TraceVisibilityFilter;
51+ using ::tensorflow::profiler::XprofThreadPoolExecutor;
4652using ::tensorflow::profiler::XSpace;
47- using internal::GetTraceViewOption;
48- using internal::TraceViewOption;
4953
5054namespace {
5155// Traces with fewer events than this threshold will have streaming disabled.
@@ -63,7 +67,6 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession(
6367
6468 // TODO: b/452217676 - Optimize this to process hosts in parallel.
6569 for (int i = 0 ; i < session_snapshot.XSpaceSize (); ++i) {
66- int host_id = i+1 ;
6770 google::protobuf::Arena arena;
6871 TF_ASSIGN_OR_RETURN (XSpace * xspace, session_snapshot.GetXSpace (i, &arena));
6972 PreprocessSingleHostXSpace (xspace, /* step_grouping=*/ true ,
@@ -143,7 +146,7 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession(
143146 file_paths, std::move (trace_events_filter),
144147 std::move (visibility_filter), kDisableStreamingThreshold ));
145148 }
146- merged_trace_container.Merge (std::move (trace_container), host_id );
149+ merged_trace_container.Merge (std::move (trace_container), i + 1 );
147150 }
148151
149152 std::string trace_viewer_json;
@@ -250,12 +253,13 @@ absl::StatusOr<TraceEventsContainer> LoadTraceContainerForHost(
250253 if (!metadata_path.has_value () ||
251254 !tsl::Env::Default ()->FileExists (*metadata_path).ok ()) {
252255 return tsl::errors::Internal (" Could not find metadata file for host: " ,
253- hostname, " , path: " , *metadata_path);
256+ hostname,
257+ " , path: " , metadata_path.value_or (" " ));
254258 }
255259 if (!trie_path.has_value () ||
256260 !tsl::Env::Default ()->FileExists (*trie_path).ok ()) {
257261 return tsl::errors::Internal (" Could not find trie file for host: " ,
258- hostname, " , path: " , * trie_path);
262+ hostname, " , path: " , trie_path. value_or ( " " ) );
259263 }
260264 file_paths.trace_events_metadata_file_path = *metadata_path;
261265 file_paths.trace_events_prefix_trie_file_path = *trie_path;
@@ -306,18 +310,41 @@ absl::Status StreamingTraceViewerProcessor::Reduce(
306310 tensorflow::profiler::TraceOptions profiler_trace_options =
307311 TraceOptionsFromToolOptions (options_);
308312
313+ int num_hosts = map_output_files.size ();
314+ int num_threads = std::min (num_hosts, tsl::port::MaxParallelism ());
315+ std::vector<absl::StatusOr<TraceEventsContainer>> trace_containers (num_hosts);
316+
317+ {
318+ XprofThreadPoolExecutor executor (" StreamingTraceViewerReduce" , num_threads);
319+ for (int i = 0 ; i < num_hosts; ++i) {
320+ executor.Execute ([&session_snapshot, &map_output_files, &trace_option,
321+ &profiler_trace_options, &trace_containers, i]() {
322+ trace_containers[i] =
323+ LoadTraceContainerForHost (session_snapshot, map_output_files[i],
324+ trace_option, profiler_trace_options);
325+ });
326+ }
327+ executor.JoinAll ();
328+ }
329+
309330 TraceEventsContainer merged_trace_container;
331+ int successful_hosts = 0 ;
332+ for (int i = 0 ; i < num_hosts; ++i) {
333+ if (!trace_containers[i].ok ()) {
334+ LOG (ERROR) << " Skipping host " << i
335+ << " due to failure: " << trace_containers[i].status ();
336+ continue ;
337+ }
310338
311- for (int i = 0 ; i < map_output_files.size (); ++i) {
312- const std::string& trace_events_sstable_path = map_output_files[i];
313- int host_id = i + 1 ;
339+ TF_ASSIGN_OR_RETURN (TraceEventsContainer trace_container,
340+ std::move (trace_containers[i]));
314341
315- TF_ASSIGN_OR_RETURN (
316- TraceEventsContainer trace_container,
317- LoadTraceContainerForHost (session_snapshot, trace_events_sstable_path,
318- trace_option, profiler_trace_options));
342+ merged_trace_container.Merge (std::move (trace_container), i + 1 );
343+ successful_hosts++;
344+ }
319345
320- merged_trace_container.Merge (std::move (trace_container), host_id);
346+ if (successful_hosts == 0 ) {
347+ return absl::InternalError (" No hosts with valid trace data." );
321348 }
322349
323350 std::string trace_viewer_json;
0 commit comments