@@ -87,20 +87,22 @@ int main(int argc, char *argv[])
 {
 #if USE_MPI
   int provided;
+  int localRank;
+
   MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED, &provided);
-  if (provided < MPI_THREAD_FUNNELED) {
+
+  if (provided < MPI_THREAD_FUNNELED)
     MPI_Abort(MPI_COMM_WORLD, provided);
-  }
 
   MPI_Comm_rank(MPI_COMM_WORLD, &rank);
   MPI_Comm_size(MPI_COMM_WORLD, &procs);
 
-  // Each local rank on a given node will own a single device/GCD
-  MPI_Comm shmcomm;
+  // Each rank will run the benchmark on a single device
+  MPI_Comm shared_comm;
   MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0,
-                      MPI_INFO_NULL, &shmcomm);
-  int localRank;
-  MPI_Comm_rank(shmcomm, &localRank);
+                      MPI_INFO_NULL, &shared_comm);
+  MPI_Comm_rank(shared_comm, &localRank);
+
   // Set device index to be the local MPI rank
   deviceIndex = localRank;
 #endif
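
The shared-memory split above is what turns a global MPI rank into a per-node device index: MPI_COMM_TYPE_SHARED groups exactly the ranks that share a node, so each node's ranks are numbered from zero and can each claim their own GPU. A minimal standalone sketch of the same pattern (the helper name node_local_rank and the MPI_Comm_free call are illustrative, not part of this patch):

    #include <mpi.h>

    // Sketch: compute this rank's index among the ranks on its own node.
    int node_local_rank()
    {
      MPI_Comm shared_comm;
      // Group the ranks that can share memory, i.e. the ranks on this node
      MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0,
                          MPI_INFO_NULL, &shared_comm);
      int local_rank;
      MPI_Comm_rank(shared_comm, &local_rank);
      MPI_Comm_free(&shared_comm);  // only needed for the rank query
      return local_rank;            // 0..(ranks on this node - 1), usable as a device index
    }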
@@ -110,16 +112,17 @@ int main(int argc, char *argv[])
   if (!output_as_csv)
   {
 #if USE_MPI
-    if (rank == 0) {
+    if (rank == 0)
 #endif
+    {
       std::cout
         << "BabelStream" << std::endl
         << "Version: " << VERSION_STRING << std::endl
         << "Implementation: " << IMPLEMENTATION_STRING << std::endl;
 #if USE_MPI
       std::cout << "Number of MPI ranks: " << procs << std::endl;
-    }
 #endif
+    }
   }
 
   if (use_float)
@@ -145,54 +148,48 @@ std::vector<std::vector<double>> run_all(Stream<T> *stream, T& sum)
   // Declare timers
   std::chrono::high_resolution_clock::time_point t1, t2;
 
+#if USE_MPI
+  // Set the MPI datatype for the dot-product reduction to match T
+  MPI_Datatype MPI_DTYPE = use_float ? MPI_FLOAT : MPI_DOUBLE;
+#endif
+
   // Main loop
   for (unsigned int k = 0; k < num_times; k++)
   {
-#if USE_MPI
-    MPI_Barrier(MPI_COMM_WORLD);
-#endif
 
     // Execute Copy
     t1 = std::chrono::high_resolution_clock::now();
     stream->copy();
-#if USE_MPI
-    MPI_Barrier(MPI_COMM_WORLD);
-#endif
     t2 = std::chrono::high_resolution_clock::now();
     timings[0].push_back(std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count());
 
     // Execute Mul
     t1 = std::chrono::high_resolution_clock::now();
     stream->mul();
-#if USE_MPI
-    MPI_Barrier(MPI_COMM_WORLD);
-#endif
     t2 = std::chrono::high_resolution_clock::now();
     timings[1].push_back(std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count());
 
     // Execute Add
     t1 = std::chrono::high_resolution_clock::now();
     stream->add();
-#if USE_MPI
-    MPI_Barrier(MPI_COMM_WORLD);
-#endif
     t2 = std::chrono::high_resolution_clock::now();
     timings[2].push_back(std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count());
 
     // Execute Triad
     t1 = std::chrono::high_resolution_clock::now();
     stream->triad();
-#if USE_MPI
-    MPI_Barrier(MPI_COMM_WORLD);
-#endif
     t2 = std::chrono::high_resolution_clock::now();
     timings[3].push_back(std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count());
 
     // Execute Dot
+#if USE_MPI
+    // Synchronize ranks before timing the dot kernel and its reduction
+    MPI_Barrier(MPI_COMM_WORLD);
+#endif
     t1 = std::chrono::high_resolution_clock::now();
     sum = stream->dot();
 #if USE_MPI
-    MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
+    MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_DTYPE, MPI_SUM, MPI_COMM_WORLD);
 #endif
     t2 = std::chrono::high_resolution_clock::now();
     timings[4].push_back(std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count());
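
The datatype change in the Allreduce above fixes a real mismatch: sum has the template type T, so in float mode the old call reduced a 4-byte float as MPI_DOUBLE, reading past the variable. Selecting the datatype from use_float matches the benchmark's existing global flag; an alternative sketch that derives it from T itself (illustrative, not what the patch does; mpi_dtype is a hypothetical helper):

    #include <mpi.h>

    // Hypothetical helper: map the element type to the matching MPI datatype.
    template <typename T>
    MPI_Datatype mpi_dtype()
    {
      return (sizeof(T) == sizeof(float)) ? MPI_FLOAT : MPI_DOUBLE;
    }

    // Usage inside run_all<T> would then be:
    //   MPI_Allreduce(MPI_IN_PLACE, &sum, 1, mpi_dtype<T>(), MPI_SUM, MPI_COMM_WORLD);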
@@ -217,9 +214,6 @@ std::vector<std::vector<double>> run_triad(Stream<T> *stream)
   t1 = std::chrono::high_resolution_clock::now();
   for (unsigned int k = 0; k < num_times; k++)
   {
-#if USE_MPI
-    MPI_Barrier(MPI_COMM_WORLD);
-#endif
     stream->triad();
   }
   t2 = std::chrono::high_resolution_clock::now();
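
With the in-loop barrier removed, run_triad times one uninterrupted burst of num_times triad kernels per rank. The reporting code outside this diff derives bandwidth from the total bytes moved over that burst; a sketch of that arithmetic, assuming BabelStream's usual accounting of three arrays touched per triad (a = b + scalar * c) and its ARRAY_SIZE and num_times globals:

    // Elapsed seconds for the whole burst of num_times triads
    double runtime = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
    // Triad reads b and c and writes a: 3 arrays of ARRAY_SIZE elements per iteration
    double total_bytes = 3.0 * sizeof(T) * ARRAY_SIZE * num_times;
    double bandwidth = 1.0E-9 * total_bytes / runtime;  // GB/s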
@@ -241,14 +235,8 @@ std::vector<std::vector<double>> run_nstream(Stream<T> *stream)
 
   // Run nstream in loop
   for (int k = 0; k < num_times; k++) {
-#if USE_MPI
-    MPI_Barrier(MPI_COMM_WORLD);
-#endif
     t1 = std::chrono::high_resolution_clock::now();
     stream->nstream();
-#if USE_MPI
-    MPI_Barrier(MPI_COMM_WORLD);
-#endif
     t2 = std::chrono::high_resolution_clock::now();
     timings[0].push_back(std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count());
   }
@@ -416,10 +404,6 @@ void run()
 
 
   stream->read_arrays(a, b, c);
-#if USE_MPI
-  // Only check solutions on rank 0 in case verificaiton fails
-  if (rank == 0)
-#endif
   check_solution<T>(num_times, a, b, c, sum);
 
   // Display timing results
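
Dropping the rank-0 guard means every rank verifies the arrays read back from its own device, so a faulty device anywhere in the job is caught, not only on rank 0. If a single job-wide verdict were wanted, one hypothetical extension (not in this patch; local_errors is an illustrative name for a per-rank failure count) is a logical-OR reduction over a failure flag:

    // Hypothetical: agree on a global pass/fail after per-rank verification
    int failed = (local_errors > 0) ? 1 : 0;
    MPI_Allreduce(MPI_IN_PLACE, &failed, 1, MPI_INT, MPI_LOR, MPI_COMM_WORLD);
    if (failed)
      MPI_Abort(MPI_COMM_WORLD, 1);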
@@ -485,17 +469,11 @@ void run()
   double max = *minmax.second;
 
 #if USE_MPI
-  // Collate timings
-  if (rank == 0)
-  {
-    MPI_Reduce(MPI_IN_PLACE, &min, 1, MPI_DOUBLE, MPI_MIN, 0, MPI_COMM_WORLD);
-    MPI_Reduce(MPI_IN_PLACE, &max, 1, MPI_DOUBLE, MPI_MAX, 0, MPI_COMM_WORLD);
-  }
-  else
-  {
-    MPI_Reduce(&min, NULL, 1, MPI_DOUBLE, MPI_MIN, 0, MPI_COMM_WORLD);
-    MPI_Reduce(&max, NULL, 1, MPI_DOUBLE, MPI_MAX, 0, MPI_COMM_WORLD);
-  }
+  // Collect global min/max timings; timings are seconds stored as
+  // double, so use MPI_DOUBLE here regardless of use_float
+
+  MPI_Allreduce(MPI_IN_PLACE, &min, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD);
+  MPI_Allreduce(MPI_IN_PLACE, &max, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
   sizes[i] *= procs;
 #endif
 
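The symmetric MPI_Allreduce leaves every rank holding the global extrema, removing the rank-0 special case, and sizes[i] *= procs scales the per-rank byte count to the whole job, so the bandwidth printed downstream becomes an aggregate figure. A sketch of that downstream arithmetic, assuming BabelStream's usual MBytes/sec formula (not part of this hunk):

    // After the reductions: min/max are global, sizes[i] is job-wide bytes
    double mbytes_per_sec = 1.0E-6 * sizes[i] / min;  // total bytes / fastest time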