Commit 666f068

[fix] Stabilize MPI tests and prevent hangs (#672)
- Sync gtest random_seed + filter across ranks
- Abort on worker failure; wrap RunAllTests with MPI_Abort on exceptions
- Make PPC_TEST_TMPDIR per MPI rank
- Pass env to mpiexec on Windows (-env), keep -x on *nix
1 parent 5d9ac5d commit 666f068

File tree: 3 files changed (+170, -27 lines)

  modules/runners/src/runners.cpp
  modules/util/include/util.hpp
  scripts/run_tests.py


modules/runners/src/runners.cpp

Lines changed: 67 additions & 21 deletions
@@ -6,12 +6,14 @@
 #include <chrono>
 #include <cstdint>
 #include <cstdlib>
+#include <exception>
 #include <format>
 #include <iostream>
 #include <memory>
 #include <random>
 #include <stdexcept>
 #include <string>
+#include <string_view>

 #include "oneapi/tbb/global_control.h"
 #include "util/include/util.hpp"
@@ -51,6 +53,8 @@ void WorkerTestFailurePrinter::OnTestEnd(const ::testing::TestInfo &test_info) {
   }
   PrintProcessRank();
   base_->OnTestEnd(test_info);
+  // Abort the whole MPI job on any test failure to avoid other ranks hanging on barriers.
+  MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
 }

 void WorkerTestFailurePrinter::OnTestPartResult(const ::testing::TestPartResult &test_part_result) {
@@ -76,6 +80,63 @@ int RunAllTests() {
   }
   return status;
 }
+
+void SyncGTestSeed() {
+  unsigned int seed = 0;
+  int rank = -1;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+  if (rank == 0) {
+    try {
+      seed = std::random_device{}();
+    } catch (...) {
+      seed = 0;
+    }
+    if (seed == 0) {
+      const auto now = static_cast<std::uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
+      seed = static_cast<unsigned int>(((now & 0x7fffffffULL) | 1ULL));
+    }
+  }
+  MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
+  ::testing::GTEST_FLAG(random_seed) = static_cast<int>(seed);
+}
+
+void SyncGTestFilter() {
+  int rank = -1;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+  std::string filter = (rank == 0) ? ::testing::GTEST_FLAG(filter) : std::string{};
+  int len = static_cast<int>(filter.size());
+  MPI_Bcast(&len, 1, MPI_INT, 0, MPI_COMM_WORLD);
+  if (rank != 0) {
+    filter.resize(static_cast<std::size_t>(len));
+  }
+  if (len > 0) {
+    MPI_Bcast(filter.data(), len, MPI_CHAR, 0, MPI_COMM_WORLD);
+  }
+  ::testing::GTEST_FLAG(filter) = filter;
+}
+
+bool HasFlag(int argc, char **argv, std::string_view flag) {
+  for (int i = 1; i < argc; ++i) {
+    if (argv[i] != nullptr && std::string_view(argv[i]) == flag) {
+      return true;
+    }
+  }
+  return false;
+}
+
+int RunAllTestsSafely() {
+  try {
+    return RunAllTests();
+  } catch (const std::exception &e) {
+    std::cerr << std::format("[ ERROR ] Exception after tests: {}", e.what()) << '\n';
+    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
+    return EXIT_FAILURE;
+  } catch (...) {
+    std::cerr << "[ ERROR ] Unknown exception after tests" << '\n';
+    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
+    return EXIT_FAILURE;
+  }
+}
 }  // namespace

 int Init(int argc, char **argv) {
@@ -91,36 +152,21 @@ int Init(int argc, char **argv) {

   ::testing::InitGoogleTest(&argc, argv);

-  // Ensure consistent GoogleTest shuffle order across all MPI ranks.
-  unsigned int seed = 0;
-  int rank_for_seed = -1;
-  MPI_Comm_rank(MPI_COMM_WORLD, &rank_for_seed);
-
-  if (rank_for_seed == 0) {
-    try {
-      seed = std::random_device{}();
-    } catch (...) {
-      seed = 0;
-    }
-    if (seed == 0) {
-      const auto now = static_cast<std::uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
-      seed = static_cast<unsigned int>(((now & 0x7fffffffULL) | 1ULL));
-    }
-  }
-
-  MPI_Bcast(&seed, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
-  ::testing::GTEST_FLAG(random_seed) = static_cast<int>(seed);
+  // Synchronize GoogleTest internals across ranks to avoid divergence
+  SyncGTestSeed();
+  SyncGTestFilter();

   auto &listeners = ::testing::UnitTest::GetInstance()->listeners();
   int rank = -1;
   MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-  if (rank != 0 && (argc < 2 || argv[1] != std::string("--print-workers"))) {
+  const bool print_workers = HasFlag(argc, argv, "--print-workers");
+  if (rank != 0 && !print_workers) {
     auto *listener = listeners.Release(listeners.default_result_printer());
     listeners.Append(new WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener)));
   }
   listeners.Append(new UnreadMessagesDetector());

-  auto status = RunAllTests();
+  const int status = RunAllTestsSafely();

   const int finalize_res = MPI_Finalize();
   if (finalize_res != MPI_SUCCESS) {
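
Why the abort-on-failure path matters: in an MPI-parallel gtest binary, a fatal assertion on one rank only ends that rank's test body, while any other rank already waiting in a collective call never returns. The following is a minimal sketch of that failure mode, not part of the commit; the suite/test names and the rank-parity condition are invented for illustration.

#include <gtest/gtest.h>
#include <mpi.h>

// Hypothetical test: ASSERT_EQ fails on odd (worker) ranks and returns from the
// test body, so those ranks never reach the barrier below; the passing ranks
// block in MPI_Barrier indefinitely. With this commit, the failing worker's
// OnTestEnd calls MPI_Abort and tears the whole job down instead of hanging.
TEST(ExampleSuite, BarrierAfterAssert) {
  int rank = -1;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  ASSERT_EQ(rank % 2, 0);       // fails on ranks 1, 3, ... and returns early
  MPI_Barrier(MPI_COMM_WORLD);  // even ranks would wait here forever
}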

modules/util/include/util.hpp

Lines changed: 15 additions & 1 deletion
@@ -1,6 +1,7 @@
 #pragma once

 #include <algorithm>
+#include <array>
 #include <atomic>
 #include <cctype>
 #include <cstdint>
@@ -26,6 +27,7 @@
 #include <gtest/gtest.h>

 #include <libenvpp/detail/environment.hpp>
+#include <libenvpp/detail/get.hpp>
 #include <nlohmann/json.hpp>

 /// @brief JSON namespace used for settings and config parsing.
@@ -123,7 +125,19 @@ class ScopedPerTestEnv {
  private:
   static std::string CreateTmpDir(const std::string &token) {
     namespace fs = std::filesystem;
-    const fs::path tmp = fs::temp_directory_path() / (std::string("ppc_test_") + token);
+    auto make_rank_suffix = []() -> std::string {
+      // Derive rank from common MPI env vars without including MPI headers
+      constexpr std::array<std::string_view, 5> kRankVars = {"OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK",
+                                                             "SLURM_PROCID", "MSMPI_RANK"};
+      for (auto name : kRankVars) {
+        if (auto r = env::get<int>(name); r.has_value() && r.value() >= 0) {
+          return std::string("_rank_") + std::to_string(r.value());
+        }
+      }
+      return std::string{};
+    };
+    const std::string rank_suffix = IsUnderMpirun() ? make_rank_suffix() : std::string{};
+    const fs::path tmp = fs::temp_directory_path() / (std::string("ppc_test_") + token + rank_suffix);
     std::error_code ec;
     fs::create_directories(tmp, ec);
     (void)ec;
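
The effect is that each rank gets its own scratch directory: with four ranks under mpirun, CreateTmpDir("<token>") yields sibling paths like .../ppc_test_<token>_rank_0 through _rank_3 instead of one shared directory. For reference, here is a standalone sketch of the same rank-derivation idea using std::getenv instead of libenvpp's env::get<int>; the function name is made up, and unlike the committed lambda it skips the non-negative integer check.

#include <cstdlib>
#include <string>

// Scans the common launcher-specific rank variables and returns e.g. "_rank_3",
// or an empty string when none is set (i.e. not running under an MPI launcher).
std::string RankSuffixFromEnv() {
  for (const char *name : {"OMPI_COMM_WORLD_RANK", "PMI_RANK", "PMIX_RANK", "SLURM_PROCID", "MSMPI_RANK"}) {
    const char *value = std::getenv(name);
    if (value != nullptr && *value != '\0') {
      return std::string("_rank_") + value;
    }
  }
  return {};
}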

scripts/run_tests.py

Lines changed: 88 additions & 5 deletions
@@ -53,6 +53,17 @@ def __init__(self, verbose=False):
             self.mpi_exec = "mpiexec"
         else:
             self.mpi_exec = "mpirun"
+        self.platform = platform.system()
+
+        # Detect MPI implementation to choose compatible flags
+        self.mpi_env_mode = "unknown"  # one of: openmpi, mpich, unknown
+        self.mpi_np_flag = "-np"
+        if self.platform == "Windows":
+            # MS-MPI uses -env and -n
+            self.mpi_env_mode = "mpich"
+            self.mpi_np_flag = "-n"
+        else:
+            self.mpi_env_mode, self.mpi_np_flag = self.__detect_mpi_impl()

     @staticmethod
     def __get_project_path():
@@ -88,6 +99,81 @@ def __run_exec(self, command):
         if result.returncode != 0:
             raise Exception(f"Subprocess return {result.returncode}.")

+    def __detect_mpi_impl(self):
+        """Detect MPI implementation and return (env_mode, np_flag).
+        env_mode: 'openmpi' -> use '-x VAR', 'mpich' -> use '-env VAR VALUE', 'unknown' -> pass no env flags.
+        np_flag: '-np' for OpenMPI/unknown, '-n' for MPICH-family.
+        """
+        probes = (["--version"], ["-V"], ["-v"], ["--help"], ["-help"])
+        out = ""
+        for args in probes:
+            try:
+                proc = subprocess.run(
+                    [self.mpi_exec] + list(args),
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.STDOUT,
+                    text=True,
+                )
+                out = (proc.stdout or "").lower()
+                if out:
+                    break
+            except Exception:
+                continue
+
+        if "open mpi" in out or "ompi" in out:
+            return "openmpi", "-np"
+        if (
+            "hydra" in out
+            or "mpich" in out
+            or "intel(r) mpi" in out
+            or "intel mpi" in out
+        ):
+            return "mpich", "-n"
+        return "unknown", "-np"
+
+    def __build_mpi_cmd(self, ppc_num_proc, additional_mpi_args):
+        base = [self.mpi_exec] + shlex.split(additional_mpi_args)
+
+        if self.platform == "Windows":
+            # MS-MPI style
+            env_args = [
+                "-env",
+                "PPC_NUM_THREADS",
+                self.__ppc_env["PPC_NUM_THREADS"],
+                "-env",
+                "OMP_NUM_THREADS",
+                self.__ppc_env["OMP_NUM_THREADS"],
+            ]
+            np_args = ["-n", ppc_num_proc]
+            return base + env_args + np_args
+
+        # Non-Windows
+        if self.mpi_env_mode == "openmpi":
+            env_args = [
+                "-x",
+                "PPC_NUM_THREADS",
+                "-x",
+                "OMP_NUM_THREADS",
+            ]
+            np_flag = "-np"
+        elif self.mpi_env_mode == "mpich":
+            # Explicitly set env variables for all ranks
+            env_args = [
+                "-env",
+                "PPC_NUM_THREADS",
+                self.__ppc_env["PPC_NUM_THREADS"],
+                "-env",
+                "OMP_NUM_THREADS",
+                self.__ppc_env["OMP_NUM_THREADS"],
+            ]
+            np_flag = "-n"
+        else:
+            # Unknown MPI flavor: rely on environment inheritance and default to -np
+            env_args = []
+            np_flag = "-np"
+
+        return base + env_args + [np_flag, ppc_num_proc]
+
     @staticmethod
     def __get_gtest_settings(repeats_count, type_task):
         command = [
@@ -133,10 +219,7 @@ def run_processes(self, additional_mpi_args):
             raise EnvironmentError(
                 "Required environment variable 'PPC_NUM_PROC' is not set."
             )
-
-        mpi_running = (
-            [self.mpi_exec] + shlex.split(additional_mpi_args) + ["-np", ppc_num_proc]
-        )
+        mpi_running = self.__build_mpi_cmd(ppc_num_proc, additional_mpi_args)
         if not self.__ppc_env.get("PPC_ASAN_RUN"):
            for task_type in ["all", "mpi"]:
                 self.__run_exec(
@@ -147,7 +230,7 @@ def run_performance(self):

     def run_performance(self):
         if not self.__ppc_env.get("PPC_ASAN_RUN"):
-            mpi_running = [self.mpi_exec, "-np", self.__ppc_num_proc]
+            mpi_running = self.__build_mpi_cmd(self.__ppc_num_proc, "")
             for task_type in ["all", "mpi"]:
                 self.__run_exec(
                     mpi_running
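
Taken together, __build_mpi_cmd now emits a launcher-appropriate command line. As an illustration only, assuming PPC_NUM_THREADS=2, OMP_NUM_THREADS=2, PPC_NUM_PROC=4 and no additional MPI args, the Open MPI branch produces roughly

  mpirun -x PPC_NUM_THREADS -x OMP_NUM_THREADS -np 4 <test binary> <gtest flags>

while on Windows (MS-MPI) and for the MPICH/Intel MPI family the values are passed explicitly, e.g.

  mpiexec -env PPC_NUM_THREADS 2 -env OMP_NUM_THREADS 2 -n 4 <test binary> <gtest flags>

and an unrecognized launcher gets no env flags at all, falling back to -np and plain environment inheritance.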
