Skip to content

Commit 54ebdfb

Browse files
committed
infra: stabilize MPI tests and prevent hangs
- Sync gtest random_seed + filter across ranks
- Abort on worker failure; wrap RunAllTests with MPI_Abort on exceptions
- Make PPC_TEST_TMPDIR per MPI rank
- Pass env to mpiexec on Windows (-env), keep -x on *nix
1 parent 5d9ac5d commit 54ebdfb

File tree

3 files changed

+178
-27
lines changed

3 files changed

+178
-27
lines changed

modules/runners/src/runners.cpp

Lines changed: 112 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,63 @@
66
#include <chrono>
77
#include <cstdint>
88
#include <cstdlib>
9+
#include <exception>
910
#include <format>
1011
#include <iostream>
1112
#include <memory>
1213
#include <random>
1314
#include <stdexcept>
1415
#include <string>
16+
#include <string_view>
1517

1618
#include "oneapi/tbb/global_control.h"
1719
#include "util/include/util.hpp"
1820

21+
namespace {
22+
/// @brief Makes every MPI rank use the same GoogleTest random seed.
///
/// Rank 0 chooses the seed and broadcasts it over MPI_COMM_WORLD; every rank
/// then stores the received value in the `random_seed` flag.
///
/// Seed selection on rank 0, in priority order:
///   1. a non-zero value already present in the `random_seed` flag, so that a
///      seed supplied by the user is honored and a shuffled run stays reproducible;
///   2. a value drawn from std::random_device;
///   3. a value derived from the steady clock, forced to be odd (hence non-zero).
[[maybe_unused]] void SyncGTestSeed() {
  int rank = -1;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  unsigned int seed = 0;
  if (rank == 0) {
    // Keep an explicitly requested seed instead of silently replacing it.
    seed = static_cast<unsigned int>(::testing::GTEST_FLAG(random_seed));
    if (seed == 0) {
      try {
        seed = std::random_device{}();
      } catch (...) {
        // Entropy source unavailable: fall through to the clock-based fallback.
        seed = 0;
      }
    }
    if (seed == 0) {
      const auto now = static_cast<std::uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
      // Mask to 31 bits and set the lowest bit so the result can never be zero.
      seed = static_cast<unsigned int>(((now & 0x7fffffffULL) | 1ULL));
    }
  }

  MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
  ::testing::GTEST_FLAG(random_seed) = static_cast<int>(seed);
}
40+
41+
[[maybe_unused]] void SyncGTestFilter() {
42+
int rank = -1;
43+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
44+
std::string filter = (rank == 0) ? ::testing::GTEST_FLAG(filter) : std::string{};
45+
int len = static_cast<int>(filter.size());
46+
MPI_Bcast(&len, 1, MPI_INT, 0, MPI_COMM_WORLD);
47+
if (rank != 0) {
48+
filter.resize(static_cast<std::size_t>(len));
49+
}
50+
if (len > 0) {
51+
MPI_Bcast(filter.data(), len, MPI_CHAR, 0, MPI_COMM_WORLD);
52+
}
53+
::testing::GTEST_FLAG(filter) = filter;
54+
}
55+
56+
/// @brief Tells whether @p flag appears verbatim among the command-line arguments.
/// @param argc Number of entries in @p argv.
/// @param argv Argument vector; entry 0 is never inspected and null entries are skipped.
/// @param flag Exact text to look for (whole-argument match, no prefix matching).
/// @return true if some argv[i] with 1 <= i < argc equals @p flag, false otherwise.
[[maybe_unused]] bool HasFlag(int argc, char **argv, std::string_view flag) {
  int idx = 1;
  while (idx < argc) {
    const char *arg = argv[idx];
    ++idx;
    if (arg == nullptr) {
      continue;
    }
    if (flag.compare(arg) == 0) {
      return true;
    }
  }
  return false;
}
64+
} // namespace
65+
1966
namespace ppc::runners {
2067

2168
void UnreadMessagesDetector::OnTestEnd(const ::testing::TestInfo & /*test_info*/) {
@@ -51,6 +98,8 @@ void WorkerTestFailurePrinter::OnTestEnd(const ::testing::TestInfo &test_info) {
5198
}
5299
PrintProcessRank();
53100
base_->OnTestEnd(test_info);
101+
// Abort the whole MPI job on any test failure to avoid other ranks hanging on barriers.
102+
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
54103
}
55104

56105
void WorkerTestFailurePrinter::OnTestPartResult(const ::testing::TestPartResult &test_part_result) {
@@ -76,6 +125,63 @@ int RunAllTests() {
76125
}
77126
return status;
78127
}
128+
129+
/// @brief Makes every MPI rank use the same GoogleTest random seed.
///
/// Rank 0 chooses the seed and broadcasts it over MPI_COMM_WORLD; every rank
/// then stores the received value in the `random_seed` flag.
///
/// Seed selection on rank 0, in priority order:
///   1. a non-zero value already present in the `random_seed` flag, so that a
///      seed supplied by the user is honored and a shuffled run stays reproducible;
///   2. a value drawn from std::random_device;
///   3. a value derived from the steady clock, forced to be odd (hence non-zero).
void SyncGTestSeed() {
  int rank = -1;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  unsigned int seed = 0;
  if (rank == 0) {
    // Keep an explicitly requested seed instead of silently replacing it.
    seed = static_cast<unsigned int>(::testing::GTEST_FLAG(random_seed));
    if (seed == 0) {
      try {
        seed = std::random_device{}();
      } catch (...) {
        // Entropy source unavailable: fall through to the clock-based fallback.
        seed = 0;
      }
    }
    if (seed == 0) {
      const auto now = static_cast<std::uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
      // Mask to 31 bits and set the lowest bit so the result can never be zero.
      seed = static_cast<unsigned int>(((now & 0x7fffffffULL) | 1ULL));
    }
  }

  MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
  ::testing::GTEST_FLAG(random_seed) = static_cast<int>(seed);
}
147+
148+
void SyncGTestFilter() {
149+
int rank = -1;
150+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
151+
std::string filter = (rank == 0) ? ::testing::GTEST_FLAG(filter) : std::string{};
152+
int len = static_cast<int>(filter.size());
153+
MPI_Bcast(&len, 1, MPI_INT, 0, MPI_COMM_WORLD);
154+
if (rank != 0) {
155+
filter.resize(static_cast<std::size_t>(len));
156+
}
157+
if (len > 0) {
158+
MPI_Bcast(filter.data(), len, MPI_CHAR, 0, MPI_COMM_WORLD);
159+
}
160+
::testing::GTEST_FLAG(filter) = filter;
161+
}
162+
163+
/// @brief Tells whether @p flag appears verbatim among the command-line arguments.
/// @param argc Number of entries in @p argv.
/// @param argv Argument vector; entry 0 is never inspected and null entries are skipped.
/// @param flag Exact text to look for (whole-argument match, no prefix matching).
/// @return true if some argv[i] with 1 <= i < argc equals @p flag, false otherwise.
bool HasFlag(int argc, char **argv, std::string_view flag) {
  bool found = false;
  // The loop stops at the first match or when the arguments are exhausted.
  for (int pos = 1; !found && pos < argc; ++pos) {
    const char *candidate = argv[pos];
    found = (candidate != nullptr) && (flag == std::string_view{candidate});
  }
  return found;
}
171+
172+
/// @brief Runs RunAllTests() and aborts the MPI job if an exception escapes it.
///
/// On any escaping exception a diagnostic is written to std::cerr and
/// MPI_Abort is called on MPI_COMM_WORLD with EXIT_FAILURE.
///
/// @return The value returned by RunAllTests() when it completes normally;
///         EXIT_FAILURE on the exception paths, should MPI_Abort return.
int RunAllTestsSafely() {
  // Common tail of both handlers: abort the job, then report failure.
  const auto abort_job = []() -> int {
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
    return EXIT_FAILURE;
  };

  try {
    return RunAllTests();
  } catch (const std::exception &ex) {
    std::cerr << std::format("[ ERROR ] Exception after tests: {}", ex.what()) << '\n';
    return abort_job();
  } catch (...) {
    std::cerr << "[ ERROR ] Unknown exception after tests" << '\n';
    return abort_job();
  }
}
79185
} // namespace
80186

81187
int Init(int argc, char **argv) {
@@ -91,36 +197,21 @@ int Init(int argc, char **argv) {
91197

92198
::testing::InitGoogleTest(&argc, argv);
93199

94-
// Ensure consistent GoogleTest shuffle order across all MPI ranks.
95-
unsigned int seed = 0;
96-
int rank_for_seed = -1;
97-
MPI_Comm_rank(MPI_COMM_WORLD, &rank_for_seed);
98-
99-
if (rank_for_seed == 0) {
100-
try {
101-
seed = std::random_device{}();
102-
} catch (...) {
103-
seed = 0;
104-
}
105-
if (seed == 0) {
106-
const auto now = static_cast<std::uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
107-
seed = static_cast<unsigned int>(((now & 0x7fffffffULL) | 1ULL));
108-
}
109-
}
110-
111-
MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
112-
::testing::GTEST_FLAG(random_seed) = static_cast<int>(seed);
200+
// Synchronize GoogleTest internals across ranks to avoid divergence
201+
SyncGTestSeed();
202+
SyncGTestFilter();
113203

114204
auto &listeners = ::testing::UnitTest::GetInstance()->listeners();
115205
int rank = -1;
116206
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
117-
if (rank != 0 && (argc < 2 || argv[1] != std::string("--print-workers"))) {
207+
const bool print_workers = HasFlag(argc, argv, "--print-workers");
208+
if (rank != 0 && !print_workers) {
118209
auto *listener = listeners.Release(listeners.default_result_printer());
119210
listeners.Append(new WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener)));
120211
}
121212
listeners.Append(new UnreadMessagesDetector());
122213

123-
auto status = RunAllTests();
214+
const int status = RunAllTestsSafely();
124215

125216
const int finalize_res = MPI_Finalize();
126217
if (finalize_res != MPI_SUCCESS) {

modules/util/include/util.hpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <algorithm>
4+
#include <array>
45
#include <atomic>
56
#include <cctype>
67
#include <cstdint>
@@ -26,6 +27,7 @@
2627
#include <gtest/gtest.h>
2728

2829
#include <libenvpp/detail/environment.hpp>
30+
#include <libenvpp/detail/get.hpp>
2931
#include <nlohmann/json.hpp>
3032

3133
/// @brief JSON namespace used for settings and config parsing.
@@ -123,7 +125,19 @@ class ScopedPerTestEnv {
123125
private:
124126
static std::string CreateTmpDir(const std::string &token) {
125127
namespace fs = std::filesystem;
126-
const fs::path tmp = fs::temp_directory_path() / (std::string("ppc_test_") + token);
128+
auto make_rank_suffix = []() -> std::string {
129+
// Derive rank from common MPI env vars without including MPI headers
130+
constexpr std::array<std::string_view, 5> kRankVars = {"OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK",
131+
"SLURM_PROCID", "MSMPI_RANK"};
132+
for (auto name : kRankVars) {
133+
if (auto r = env::get<int>(name); r.has_value() && r.value() >= 0) {
134+
return std::string("_rank_") + std::to_string(r.value());
135+
}
136+
}
137+
return std::string{};
138+
};
139+
const std::string rank_suffix = IsUnderMpirun() ? make_rank_suffix() : std::string{};
140+
const fs::path tmp = fs::temp_directory_path() / (std::string("ppc_test_") + token + rank_suffix);
127141
std::error_code ec;
128142
fs::create_directories(tmp, ec);
129143
(void)ec;

scripts/run_tests.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(self, verbose=False):
5353
self.mpi_exec = "mpiexec"
5454
else:
5555
self.mpi_exec = "mpirun"
56+
self.platform = platform.system()
5657

5758
@staticmethod
5859
def __get_project_path():
@@ -133,10 +134,34 @@ def run_processes(self, additional_mpi_args):
133134
raise EnvironmentError(
134135
"Required environment variable 'PPC_NUM_PROC' is not set."
135136
)
136-
137-
mpi_running = (
138-
[self.mpi_exec] + shlex.split(additional_mpi_args) + ["-np", ppc_num_proc]
139-
)
137+
if self.platform == "Windows":
138+
mpi_running = (
139+
[self.mpi_exec]
140+
+ shlex.split(additional_mpi_args)
141+
+ [
142+
"-env",
143+
"PPC_NUM_THREADS",
144+
self.__ppc_env["PPC_NUM_THREADS"],
145+
"-env",
146+
"OMP_NUM_THREADS",
147+
self.__ppc_env["OMP_NUM_THREADS"],
148+
"-n",
149+
ppc_num_proc,
150+
]
151+
)
152+
else:
153+
mpi_running = (
154+
[self.mpi_exec]
155+
+ shlex.split(additional_mpi_args)
156+
+ [
157+
"-x",
158+
"PPC_NUM_THREADS",
159+
"-x",
160+
"OMP_NUM_THREADS",
161+
"-np",
162+
ppc_num_proc,
163+
]
164+
)
140165
if not self.__ppc_env.get("PPC_ASAN_RUN"):
141166
for task_type in ["all", "mpi"]:
142167
self.__run_exec(
@@ -147,7 +172,28 @@ def run_processes(self, additional_mpi_args):
147172

148173
def run_performance(self):
149174
if not self.__ppc_env.get("PPC_ASAN_RUN"):
150-
mpi_running = [self.mpi_exec, "-np", self.__ppc_num_proc]
175+
if self.platform == "Windows":
176+
mpi_running = [
177+
self.mpi_exec,
178+
"-env",
179+
"PPC_NUM_THREADS",
180+
self.__ppc_env["PPC_NUM_THREADS"],
181+
"-env",
182+
"OMP_NUM_THREADS",
183+
self.__ppc_env["OMP_NUM_THREADS"],
184+
"-n",
185+
self.__ppc_num_proc,
186+
]
187+
else:
188+
mpi_running = [
189+
self.mpi_exec,
190+
"-x",
191+
"PPC_NUM_THREADS",
192+
"-x",
193+
"OMP_NUM_THREADS",
194+
"-np",
195+
self.__ppc_num_proc,
196+
]
151197
for task_type in ["all", "mpi"]:
152198
self.__run_exec(
153199
mpi_running

0 commit comments

Comments
 (0)