66#include < chrono>
77#include < cstdint>
88#include < cstdlib>
9+ #include < exception>
910#include < format>
1011#include < iostream>
1112#include < memory>
1213#include < random>
1314#include < stdexcept>
1415#include < string>
16+ #include < string_view>
1517
1618#include " oneapi/tbb/global_control.h"
1719#include " util/include/util.hpp"
1820
21+ namespace {
22+ [[maybe_unused]] void SyncGTestSeed () {
23+ unsigned int seed = 0 ;
24+ int rank = -1 ;
25+ MPI_Comm_rank (MPI_COMM_WORLD, &rank);
26+ if (rank == 0 ) {
27+ try {
28+ seed = std::random_device{}();
29+ } catch (...) {
30+ seed = 0 ;
31+ }
32+ if (seed == 0 ) {
33+ const auto now = static_cast <std::uint64_t >(std::chrono::steady_clock::now ().time_since_epoch ().count ());
34+ seed = static_cast <unsigned int >(((now & 0x7fffffffULL ) | 1ULL ));
35+ }
36+ }
37+ MPI_Bcast (&seed, 1 , MPI_UNSIGNED, 0 , MPI_COMM_WORLD);
38+ ::testing::GTEST_FLAG (random_seed) = static_cast<int>(seed);
39+ }
40+
41+ [[maybe_unused]] void SyncGTestFilter () {
42+ int rank = -1 ;
43+ MPI_Comm_rank (MPI_COMM_WORLD, &rank);
44+ std::string filter = (rank == 0 ) ? ::testing::GTEST_FLAG (filter) : std::string{};
45+ int len = static_cast <int >(filter.size ());
46+ MPI_Bcast (&len, 1 , MPI_INT, 0 , MPI_COMM_WORLD);
47+ if (rank != 0 ) {
48+ filter.resize (static_cast <std::size_t >(len));
49+ }
50+ if (len > 0 ) {
51+ MPI_Bcast (filter.data (), len, MPI_CHAR, 0 , MPI_COMM_WORLD);
52+ }
53+ ::testing::GTEST_FLAG (filter) = filter;
54+ }
55+
56+ [[maybe_unused]] bool HasFlag (int argc, char **argv, std::string_view flag) {
57+ for (int i = 1 ; i < argc; ++i) {
58+ if (argv[i] != nullptr && std::string_view (argv[i]) == flag) {
59+ return true ;
60+ }
61+ }
62+ return false ;
63+ }
64+ } // namespace
65+
1966namespace ppc ::runners {
2067
2168void UnreadMessagesDetector::OnTestEnd (const ::testing::TestInfo & /* test_info*/ ) {
@@ -51,6 +98,8 @@ void WorkerTestFailurePrinter::OnTestEnd(const ::testing::TestInfo &test_info) {
5198 }
5299 PrintProcessRank ();
53100 base_->OnTestEnd (test_info);
101+ // Abort the whole MPI job on any test failure to avoid other ranks hanging on barriers.
102+ MPI_Abort (MPI_COMM_WORLD, EXIT_FAILURE);
54103}
55104
56105void WorkerTestFailurePrinter::OnTestPartResult (const ::testing::TestPartResult &test_part_result) {
@@ -76,6 +125,63 @@ int RunAllTests() {
76125 }
77126 return status;
78127}
128+
129+ void SyncGTestSeed () {
130+ unsigned int seed = 0 ;
131+ int rank = -1 ;
132+ MPI_Comm_rank (MPI_COMM_WORLD, &rank);
133+ if (rank == 0 ) {
134+ try {
135+ seed = std::random_device{}();
136+ } catch (...) {
137+ seed = 0 ;
138+ }
139+ if (seed == 0 ) {
140+ const auto now = static_cast <std::uint64_t >(std::chrono::steady_clock::now ().time_since_epoch ().count ());
141+ seed = static_cast <unsigned int >(((now & 0x7fffffffULL ) | 1ULL ));
142+ }
143+ }
144+ MPI_Bcast (&seed, 1 , MPI_UNSIGNED, 0 , MPI_COMM_WORLD);
145+ ::testing::GTEST_FLAG (random_seed) = static_cast<int>(seed);
146+ }
147+
148+ void SyncGTestFilter () {
149+ int rank = -1 ;
150+ MPI_Comm_rank (MPI_COMM_WORLD, &rank);
151+ std::string filter = (rank == 0 ) ? ::testing::GTEST_FLAG (filter) : std::string{};
152+ int len = static_cast <int >(filter.size ());
153+ MPI_Bcast (&len, 1 , MPI_INT, 0 , MPI_COMM_WORLD);
154+ if (rank != 0 ) {
155+ filter.resize (static_cast <std::size_t >(len));
156+ }
157+ if (len > 0 ) {
158+ MPI_Bcast (filter.data (), len, MPI_CHAR, 0 , MPI_COMM_WORLD);
159+ }
160+ ::testing::GTEST_FLAG (filter) = filter;
161+ }
162+
163+ bool HasFlag (int argc, char **argv, std::string_view flag) {
164+ for (int i = 1 ; i < argc; ++i) {
165+ if (argv[i] != nullptr && std::string_view (argv[i]) == flag) {
166+ return true ;
167+ }
168+ }
169+ return false ;
170+ }
171+
172+ int RunAllTestsSafely () {
173+ try {
174+ return RunAllTests ();
175+ } catch (const std::exception &e) {
176+ std::cerr << std::format (" [ ERROR ] Exception after tests: {}" , e.what ()) << ' \n ' ;
177+ MPI_Abort (MPI_COMM_WORLD, EXIT_FAILURE);
178+ return EXIT_FAILURE;
179+ } catch (...) {
180+ std::cerr << " [ ERROR ] Unknown exception after tests" << ' \n ' ;
181+ MPI_Abort (MPI_COMM_WORLD, EXIT_FAILURE);
182+ return EXIT_FAILURE;
183+ }
184+ }
79185} // namespace
80186
81187int Init (int argc, char **argv) {
@@ -91,36 +197,21 @@ int Init(int argc, char **argv) {
91197
92198 ::testing::InitGoogleTest (&argc, argv);
93199
94- // Ensure consistent GoogleTest shuffle order across all MPI ranks.
95- unsigned int seed = 0 ;
96- int rank_for_seed = -1 ;
97- MPI_Comm_rank (MPI_COMM_WORLD, &rank_for_seed);
98-
99- if (rank_for_seed == 0 ) {
100- try {
101- seed = std::random_device{}();
102- } catch (...) {
103- seed = 0 ;
104- }
105- if (seed == 0 ) {
106- const auto now = static_cast <std::uint64_t >(std::chrono::steady_clock::now ().time_since_epoch ().count ());
107- seed = static_cast <unsigned int >(((now & 0x7fffffffULL ) | 1ULL ));
108- }
109- }
110-
111- MPI_Bcast (&seed, 1 , MPI_UNSIGNED, 0 , MPI_COMM_WORLD);
112- ::testing::GTEST_FLAG (random_seed) = static_cast<int>(seed);
200+ // Synchronize GoogleTest internals across ranks to avoid divergence
201+ SyncGTestSeed ();
202+ SyncGTestFilter ();
113203
114204 auto &listeners = ::testing::UnitTest::GetInstance ()->listeners ();
115205 int rank = -1 ;
116206 MPI_Comm_rank (MPI_COMM_WORLD, &rank);
117- if (rank != 0 && (argc < 2 || argv[1 ] != std::string (" --print-workers" ))) {
207+ const bool print_workers = HasFlag (argc, argv, " --print-workers" );
208+ if (rank != 0 && !print_workers) {
118209 auto *listener = listeners.Release (listeners.default_result_printer ());
119210 listeners.Append (new WorkerTestFailurePrinter (std::shared_ptr<::testing::TestEventListener>(listener)));
120211 }
121212 listeners.Append (new UnreadMessagesDetector ());
122213
123- auto status = RunAllTests ();
214+ const int status = RunAllTestsSafely ();
124215
125216 const int finalize_res = MPI_Finalize ();
126217 if (finalize_res != MPI_SUCCESS) {
0 commit comments