|
Teuchos Package Browser (Single Doxygen Collection) Version of the Day
|
00001 // @HEADER 00002 // *********************************************************************** 00003 // 00004 // Teuchos: Common Tools Package 00005 // Copyright (2004) Sandia Corporation 00006 // 00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive 00008 // license for use of this work by or on behalf of the U.S. Government. 00009 // 00010 // This library is free software; you can redistribute it and/or modify 00011 // it under the terms of the GNU Lesser General Public License as 00012 // published by the Free Software Foundation; either version 2.1 of the 00013 // License, or (at your option) any later version. 00014 // 00015 // This library is distributed in the hope that it will be useful, but 00016 // WITHOUT ANY WARRANTY; without even the implied warranty of 00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00018 // Lesser General Public License for more details. 00019 // 00020 // You should have received a copy of the GNU Lesser General Public 00021 // License along with this library; if not, write to the Free Software 00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 00023 // USA 00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 00025 // 00026 // *********************************************************************** 00027 // @HEADER 00028 00029 #include "Teuchos_MPIComm.hpp" 00030 #include "Teuchos_ErrorPolling.hpp" 00031 00032 00033 using namespace Teuchos; 00034 00035 namespace Teuchos 00036 { 00037 const int MPIComm::INT = 1; 00038 const int MPIComm::FLOAT = 2; 00039 const int MPIComm::DOUBLE = 3; 00040 const int MPIComm::CHAR = 4; 00041 00042 const int MPIComm::SUM = 5; 00043 const int MPIComm::MIN = 6; 00044 const int MPIComm::MAX = 7; 00045 const int MPIComm::PROD = 8; 00046 } 00047 00048 00049 MPIComm::MPIComm() 00050 : 00051 #ifdef HAVE_MPI 00052 comm_(MPI_COMM_WORLD), 00053 #endif 00054 nProc_(0), myRank_(0) 00055 { 00056 init(); 00057 } 00058 00059 #ifdef HAVE_MPI 00060 MPIComm::MPIComm(MPI_Comm comm) 00061 : comm_(comm), nProc_(0), myRank_(0) 00062 { 00063 init(); 00064 } 00065 #endif 00066 00067 int MPIComm::mpiIsRunning() const 00068 { 00069 int mpiStarted = 0; 00070 #ifdef HAVE_MPI 00071 MPI_Initialized(&mpiStarted); 00072 #endif 00073 return mpiStarted; 00074 } 00075 00076 void MPIComm::init() 00077 { 00078 #ifdef HAVE_MPI 00079 00080 if (mpiIsRunning()) 00081 { 00082 errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank"); 00083 errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size"); 00084 } 00085 else 00086 { 00087 nProc_ = 1; 00088 myRank_ = 0; 00089 } 00090 00091 #else 00092 nProc_ = 1; 00093 myRank_ = 0; 00094 #endif 00095 } 00096 00097 #ifdef USE_MPI_GROUPS /* we're ignoring groups for now */ 00098 00099 MPIComm::MPIComm(const MPIComm& parent, const MPIGroup& group) 00100 : 00101 #ifdef HAVE_MPI 00102 comm_(MPI_COMM_WORLD), 00103 #endif 00104 nProc_(0), myRank_(0) 00105 { 00106 #ifdef HAVE_MPI 00107 if (group.getNProc()==0) 00108 { 00109 rank_ = -1; 00110 nProc_ = 0; 00111 } 00112 else if (parent.containsMe()) 00113 { 00114 MPI_Comm parentComm = parent.comm_; 00115 MPI_Group newGroup = group.group_; 00116 00117 errCheck(MPI_Comm_create(parentComm, newGroup, &comm_), 00118 "Comm_create"); 00119 00120 if (group.containsProc(parent.getRank())) 00121 { 00122 errCheck(MPI_Comm_rank(comm_, &rank_), "Comm_rank"); 00123 00124 errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size"); 00125 } 00126 else 00127 { 00128 rank_ = -1; 00129 nProc_ = -1; 00130 return; 00131 } 00132 } 00133 else 00134 { 00135 rank_ = -1; 00136 nProc_ = -1; 00137 } 00138 #endif 00139 } 00140 00141 #endif /* USE_MPI_GROUPS */ 00142 00143 MPIComm& MPIComm::world() 00144 { 00145 static MPIComm w = MPIComm(); 00146 return w; 00147 } 00148 00149 00150 MPIComm& MPIComm::self() 00151 { 00152 #ifdef HAVE_MPI 00153 static MPIComm w = MPIComm(MPI_COMM_SELF); 00154 #else 00155 static MPIComm w = MPIComm(); 00156 #endif 00157 return w; 00158 } 00159 00160 00161 void MPIComm::synchronize() const 00162 { 00163 #ifdef HAVE_MPI 00164 //mutex_.lock(); 00165 { 00166 if (mpiIsRunning()) 00167 { 00168 /* test whether errors have been detected on another proc before 00169 * doing the collective operation. */ 00170 TEUCHOS_POLL_FOR_FAILURES(*this); 00171 /* if we're to this point, all processors are OK */ 00172 00173 errCheck(::MPI_Barrier(comm_), "Barrier"); 00174 } 00175 } 00176 //mutex_.unlock(); 00177 #endif 00178 } 00179 00180 void MPIComm::allToAll(void* sendBuf, int sendCount, int sendType, 00181 void* recvBuf, int recvCount, int recvType) const 00182 { 00183 #ifdef HAVE_MPI 00184 //mutex_.lock(); 00185 { 00186 MPI_Datatype mpiSendType = getDataType(sendType); 00187 MPI_Datatype mpiRecvType = getDataType(recvType); 00188 00189 00190 if (mpiIsRunning()) 00191 { 00192 /* test whether errors have been detected on another proc before 00193 * doing the collective operation. */ 00194 TEUCHOS_POLL_FOR_FAILURES(*this); 00195 /* if we're to this point, all processors are OK */ 00196 00197 errCheck(::MPI_Alltoall(sendBuf, sendCount, mpiSendType, 00198 recvBuf, recvCount, mpiRecvType, 00199 comm_), "Alltoall"); 00200 } 00201 } 00202 //mutex_.unlock(); 00203 #else 00204 (void)sendBuf; 00205 (void)sendCount; 00206 (void)sendType; 00207 (void)recvBuf; 00208 (void)recvCount; 00209 (void)recvType; 00210 #endif 00211 } 00212 00213 void MPIComm::allToAllv(void* sendBuf, int* sendCount, 00214 int* sendDisplacements, int sendType, 00215 void* recvBuf, int* recvCount, 00216 int* recvDisplacements, int recvType) const 00217 { 00218 #ifdef HAVE_MPI 00219 //mutex_.lock(); 00220 { 00221 MPI_Datatype mpiSendType = getDataType(sendType); 00222 MPI_Datatype mpiRecvType = getDataType(recvType); 00223 00224 if (mpiIsRunning()) 00225 { 00226 /* test whether errors have been detected on another proc before 00227 * doing the collective operation. */ 00228 TEUCHOS_POLL_FOR_FAILURES(*this); 00229 /* if we're to this point, all processors are OK */ 00230 00231 errCheck(::MPI_Alltoallv(sendBuf, sendCount, sendDisplacements, mpiSendType, 00232 recvBuf, recvCount, recvDisplacements, mpiRecvType, 00233 comm_), "Alltoallv"); 00234 } 00235 } 00236 //mutex_.unlock(); 00237 #else 00238 (void)sendBuf; 00239 (void)sendCount; 00240 (void)sendDisplacements; 00241 (void)sendType; 00242 (void)recvBuf; 00243 (void)recvCount; 00244 (void)recvDisplacements; 00245 (void)recvType; 00246 #endif 00247 } 00248 00249 void MPIComm::gather(void* sendBuf, int sendCount, int sendType, 00250 void* recvBuf, int recvCount, int recvType, 00251 int root) const 00252 { 00253 #ifdef HAVE_MPI 00254 //mutex_.lock(); 00255 { 00256 MPI_Datatype mpiSendType = getDataType(sendType); 00257 MPI_Datatype mpiRecvType = getDataType(recvType); 00258 00259 00260 if (mpiIsRunning()) 00261 { 00262 /* test whether errors have been detected on another proc before 00263 * doing the collective operation. */ 00264 TEUCHOS_POLL_FOR_FAILURES(*this); 00265 /* if we're to this point, all processors are OK */ 00266 00267 errCheck(::MPI_Gather(sendBuf, sendCount, mpiSendType, 00268 recvBuf, recvCount, mpiRecvType, 00269 root, comm_), "Gather"); 00270 } 00271 } 00272 //mutex_.unlock(); 00273 #endif 00274 } 00275 00276 void MPIComm::gatherv(void* sendBuf, int sendCount, int sendType, 00277 void* recvBuf, int* recvCount, int* displacements, int recvType, 00278 int root) const 00279 { 00280 #ifdef HAVE_MPI 00281 //mutex_.lock(); 00282 { 00283 MPI_Datatype mpiSendType = getDataType(sendType); 00284 MPI_Datatype mpiRecvType = getDataType(recvType); 00285 00286 if (mpiIsRunning()) 00287 { 00288 /* test whether errors have been detected on another proc before 00289 * doing the collective operation. */ 00290 TEUCHOS_POLL_FOR_FAILURES(*this); 00291 /* if we're to this point, all processors are OK */ 00292 00293 errCheck(::MPI_Gatherv(sendBuf, sendCount, mpiSendType, 00294 recvBuf, recvCount, displacements, mpiRecvType, 00295 root, comm_), "Gatherv"); 00296 } 00297 } 00298 //mutex_.unlock(); 00299 #endif 00300 } 00301 00302 void MPIComm::allGather(void* sendBuf, int sendCount, int sendType, 00303 void* recvBuf, int recvCount, 00304 int recvType) const 00305 { 00306 #ifdef HAVE_MPI 00307 //mutex_.lock(); 00308 { 00309 MPI_Datatype mpiSendType = getDataType(sendType); 00310 MPI_Datatype mpiRecvType = getDataType(recvType); 00311 00312 if (mpiIsRunning()) 00313 { 00314 /* test whether errors have been detected on another proc before 00315 * doing the collective operation. */ 00316 TEUCHOS_POLL_FOR_FAILURES(*this); 00317 /* if we're to this point, all processors are OK */ 00318 00319 errCheck(::MPI_Allgather(sendBuf, sendCount, mpiSendType, 00320 recvBuf, recvCount, 00321 mpiRecvType, comm_), 00322 "AllGather"); 00323 } 00324 } 00325 //mutex_.unlock(); 00326 #endif 00327 } 00328 00329 00330 void MPIComm::allGatherv(void* sendBuf, int sendCount, int sendType, 00331 void* recvBuf, int* recvCount, 00332 int* recvDisplacements, 00333 int recvType) const 00334 { 00335 #ifdef HAVE_MPI 00336 //mutex_.lock(); 00337 { 00338 MPI_Datatype mpiSendType = getDataType(sendType); 00339 MPI_Datatype mpiRecvType = getDataType(recvType); 00340 00341 if (mpiIsRunning()) 00342 { 00343 /* test whether errors have been detected on another proc before 00344 * doing the collective operation. */ 00345 TEUCHOS_POLL_FOR_FAILURES(*this); 00346 /* if we're to this point, all processors are OK */ 00347 00348 errCheck(::MPI_Allgatherv(sendBuf, sendCount, mpiSendType, 00349 recvBuf, recvCount, recvDisplacements, 00350 mpiRecvType, 00351 comm_), 00352 "AllGatherv"); 00353 } 00354 } 00355 //mutex_.unlock(); 00356 #endif 00357 } 00358 00359 00360 void MPIComm::bcast(void* msg, int length, int type, int src) const 00361 { 00362 #ifdef HAVE_MPI 00363 //mutex_.lock(); 00364 { 00365 if (mpiIsRunning()) 00366 { 00367 /* test whether errors have been detected on another proc before 00368 * doing the collective operation. */ 00369 TEUCHOS_POLL_FOR_FAILURES(*this); 00370 /* if we're to this point, all processors are OK */ 00371 00372 MPI_Datatype mpiType = getDataType(type); 00373 errCheck(::MPI_Bcast(msg, length, mpiType, src, 00374 comm_), "Bcast"); 00375 } 00376 } 00377 //mutex_.unlock(); 00378 #endif 00379 } 00380 00381 void MPIComm::allReduce(void* input, void* result, int inputCount, 00382 int type, int op) const 00383 { 00384 #ifdef HAVE_MPI 00385 00386 //mutex_.lock(); 00387 { 00388 MPI_Op mpiOp = getOp(op); 00389 MPI_Datatype mpiType = getDataType(type); 00390 00391 if (mpiIsRunning()) 00392 { 00393 errCheck(::MPI_Allreduce(input, result, inputCount, mpiType, 00394 mpiOp, comm_), 00395 "Allreduce"); 00396 } 00397 } 00398 //mutex_.unlock(); 00399 #endif 00400 } 00401 00402 00403 #ifdef HAVE_MPI 00404 00405 MPI_Datatype MPIComm::getDataType(int type) 00406 { 00407 TEST_FOR_EXCEPTION( 00408 !(type == INT || type==FLOAT 00409 || type==DOUBLE || type==CHAR), 00410 std::range_error, 00411 "invalid type " << type << " in MPIComm::getDataType"); 00412 00413 if(type == INT) return MPI_INT; 00414 if(type == FLOAT) return MPI_FLOAT; 00415 if(type == DOUBLE) return MPI_DOUBLE; 00416 00417 return MPI_CHAR; 00418 } 00419 00420 00421 void MPIComm::errCheck(int errCode, const std::string& methodName) 00422 { 00423 TEST_FOR_EXCEPTION(errCode != 0, std::runtime_error, 00424 "MPI function MPI_" << methodName 00425 << " returned error code=" << errCode); 00426 } 00427 00428 MPI_Op MPIComm::getOp(int op) 00429 { 00430 00431 TEST_FOR_EXCEPTION( 00432 !(op == SUM || op==MAX 00433 || op==MIN || op==PROD), 00434 std::range_error, 00435 "invalid operator " 00436 << op << " in MPIComm::getOp"); 00437 00438 if( op == SUM) return MPI_SUM; 00439 else if( op == MAX) return MPI_MAX; 00440 else if( op == MIN) return MPI_MIN; 00441 return MPI_PROD; 00442 } 00443 00444 #endif
1.7.4