Skip to content

Commit 99b07ca

Browse files
muditgokhale2copybara-github
authored andcommitted
Add multi-threading to trace viewer Reduce
PiperOrigin-RevId: 847647539
1 parent 9839b09 commit 99b07ca

File tree

3 files changed

+197
-51
lines changed

3 files changed

+197
-51
lines changed

xprof/convert/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ cc_library(
228228
":repository",
229229
":tool_options",
230230
":xplane_to_trace_container",
231+
":xprof_thread_pool_executor",
231232
"@com_google_absl//absl/log",
232233
"@com_google_absl//absl/status",
233234
"@com_google_absl//absl/status:statusor",
@@ -240,6 +241,7 @@ cc_library(
240241
"@org_xprof//xprof/convert/trace_viewer:trace_options",
241242
"@org_xprof//xprof/convert/trace_viewer:trace_viewer_visibility",
242243
"@tsl//tsl/platform:path",
244+
"@tsl//tsl/platform:platform_port",
243245
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
244246
"@xla//xla/tsl/platform:env",
245247
"@xla//xla/tsl/platform:errors",

xprof/convert/streaming_trace_viewer_processor.cc

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "xprof/convert/streaming_trace_viewer_processor.h"
22

3+
#include <algorithm>
34
#include <cmath>
45
#include <cstdint>
56
#include <memory>
@@ -18,6 +19,7 @@
1819
#include "xla/tsl/platform/file_system.h"
1920
#include "xla/tsl/platform/statusor.h"
2021
#include "xla/tsl/profiler/utils/timespan.h"
22+
#include "tsl/platform/cpu_info.h"
2123
#include "tsl/platform/path.h"
2224
#include "tsl/profiler/protobuf/xplane.pb.h"
2325
#include "xprof/convert/preprocess_single_host_xplane.h"
@@ -30,9 +32,12 @@
3032
#include "xprof/convert/trace_viewer/trace_options.h"
3133
#include "xprof/convert/trace_viewer/trace_viewer_visibility.h"
3234
#include "xprof/convert/xplane_to_trace_container.h"
35+
#include "xprof/convert/xprof_thread_pool_executor.h"
3336

3437
namespace xprof {
3538

39+
using internal::GetTraceViewOption;
40+
using internal::TraceViewOption;
3641
using ::tensorflow::profiler::IOBufferAdapter;
3742
using ::tensorflow::profiler::JsonTraceOptions;
3843
using ::tensorflow::profiler::RawData;
@@ -43,9 +48,8 @@ using ::tensorflow::profiler::TraceEventsContainer;
4348
using ::tensorflow::profiler::TraceEventsLevelDbFilePaths;
4449
using ::tensorflow::profiler::TraceOptionsFromToolOptions;
4550
using ::tensorflow::profiler::TraceVisibilityFilter;
51+
using ::tensorflow::profiler::XprofThreadPoolExecutor;
4652
using ::tensorflow::profiler::XSpace;
47-
using internal::GetTraceViewOption;
48-
using internal::TraceViewOption;
4953

5054
namespace {
5155
// Traces with events less than threshold will be disabled from streaming.
@@ -63,7 +67,6 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession(
6367

6468
// TODO: b/452217676 - Optimize this to process hosts in parallel.
6569
for (int i = 0; i < session_snapshot.XSpaceSize(); ++i) {
66-
int host_id = i+1;
6770
google::protobuf::Arena arena;
6871
TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(i, &arena));
6972
PreprocessSingleHostXSpace(xspace, /*step_grouping=*/true,
@@ -143,7 +146,7 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession(
143146
file_paths, std::move(trace_events_filter),
144147
std::move(visibility_filter), kDisableStreamingThreshold));
145148
}
146-
merged_trace_container.Merge(std::move(trace_container), host_id);
149+
merged_trace_container.Merge(std::move(trace_container), i + 1);
147150
}
148151

149152
std::string trace_viewer_json;
@@ -250,12 +253,13 @@ absl::StatusOr<TraceEventsContainer> LoadTraceContainerForHost(
250253
if (!metadata_path.has_value() ||
251254
!tsl::Env::Default()->FileExists(*metadata_path).ok()) {
252255
return tsl::errors::Internal("Could not find metadata file for host: ",
253-
hostname, ", path: ", *metadata_path);
256+
hostname,
257+
", path: ", metadata_path.value_or(""));
254258
}
255259
if (!trie_path.has_value() ||
256260
!tsl::Env::Default()->FileExists(*trie_path).ok()) {
257261
return tsl::errors::Internal("Could not find trie file for host: ",
258-
hostname, ", path: ", *trie_path);
262+
hostname, ", path: ", trie_path.value_or(""));
259263
}
260264
file_paths.trace_events_metadata_file_path = *metadata_path;
261265
file_paths.trace_events_prefix_trie_file_path = *trie_path;
@@ -306,18 +310,41 @@ absl::Status StreamingTraceViewerProcessor::Reduce(
306310
tensorflow::profiler::TraceOptions profiler_trace_options =
307311
TraceOptionsFromToolOptions(options_);
308312

313+
int num_hosts = map_output_files.size();
314+
int num_threads = std::min(num_hosts, tsl::port::MaxParallelism());
315+
std::vector<absl::StatusOr<TraceEventsContainer>> trace_containers(num_hosts);
316+
317+
{
318+
XprofThreadPoolExecutor executor("StreamingTraceViewerReduce", num_threads);
319+
for (int i = 0; i < num_hosts; ++i) {
320+
executor.Execute([&session_snapshot, &map_output_files, &trace_option,
321+
&profiler_trace_options, &trace_containers, i]() {
322+
trace_containers[i] =
323+
LoadTraceContainerForHost(session_snapshot, map_output_files[i],
324+
trace_option, profiler_trace_options);
325+
});
326+
}
327+
executor.JoinAll();
328+
}
329+
309330
TraceEventsContainer merged_trace_container;
331+
int successful_hosts = 0;
332+
for (int i = 0; i < num_hosts; ++i) {
333+
if (!trace_containers[i].ok()) {
334+
LOG(ERROR) << "Skipping host " << i
335+
<< " due to failure: " << trace_containers[i].status();
336+
continue;
337+
}
310338

311-
for (int i = 0; i < map_output_files.size(); ++i) {
312-
const std::string& trace_events_sstable_path = map_output_files[i];
313-
int host_id = i + 1;
339+
TF_ASSIGN_OR_RETURN(TraceEventsContainer trace_container,
340+
std::move(trace_containers[i]));
314341

315-
TF_ASSIGN_OR_RETURN(
316-
TraceEventsContainer trace_container,
317-
LoadTraceContainerForHost(session_snapshot, trace_events_sstable_path,
318-
trace_option, profiler_trace_options));
342+
merged_trace_container.Merge(std::move(trace_container), i + 1);
343+
successful_hosts++;
344+
}
319345

320-
merged_trace_container.Merge(std::move(trace_container), host_id);
346+
if (successful_hosts == 0) {
347+
return absl::InternalError("No hosts with valid trace data.");
321348
}
322349

323350
std::string trace_viewer_json;

0 commit comments

Comments
 (0)