|
Anasazi Version of the Day
|
00001 #ifndef __TSQR_MpiMessenger_hpp 00002 #define __TSQR_MpiMessenger_hpp 00003 00004 #include <mpi.h> 00005 #include <Tsqr_MessengerBase.hpp> 00006 #include <Tsqr_MpiDatatype.hpp> 00007 #include <stdexcept> 00008 00011 00012 namespace TSQR { 00013 namespace MPI { 00014 00015 template< class Datum > 00016 class MpiMessenger : public TSQR::MessengerBase< Datum > { 00017 public: 00018 MpiMessenger (MPI_Comm comm) : comm_ (comm) {} 00019 virtual ~MpiMessenger () {} 00020 00021 virtual void 00022 send (const Datum sendData[], 00023 const int sendCount, 00024 const int destProc, 00025 const int tag) 00026 { 00027 const int err = 00028 MPI_Send (const_cast< Datum* const >(sendData), sendCount, 00029 mpiType_.get(), destProc, tag, comm_); 00030 if (err != MPI_SUCCESS) 00031 throw std::runtime_error ("MPI_Send failed"); 00032 } 00033 00034 virtual void 00035 recv (Datum recvData[], 00036 const int recvCount, 00037 const int srcProc, 00038 const int tag) 00039 { 00040 MPI_Status status; 00041 const int err = MPI_Recv (recvData, recvCount, mpiType_.get(), 00042 srcProc, tag, comm_, &status); 00043 if (err != MPI_SUCCESS) 00044 throw std::runtime_error ("MPI_Recv failed"); 00045 } 00046 00047 virtual void 00048 swapData (const Datum sendData[], 00049 Datum recvData[], 00050 const int sendRecvCount, 00051 const int destProc, 00052 const int tag) 00053 { 00054 MPI_Status status; 00055 const int err = 00056 MPI_Sendrecv (const_cast< Datum* const >(sendData), sendRecvCount, 00057 mpiType_.get(), destProc, tag, 00058 recvData, sendRecvCount, mpiType_.get(), destProc, tag, 00059 comm_, &status); 00060 if (err != MPI_SUCCESS) 00061 throw std::runtime_error ("MPI_Sendrecv failed"); 00062 } 00063 00064 virtual Datum 00065 globalSum (const Datum& inDatum) 00066 { 00067 Datum input (inDatum); 00068 Datum output; 00069 00070 int count = 1; 00071 const int err = MPI_Allreduce (&input, &output, count, 00072 mpiType_.get(), MPI_SUM, comm_); 00073 if (err != MPI_SUCCESS) 00074 throw std::runtime_error ("MPI_Allreduce (MPI_SUM) failed"); 00075 return output; 00076 } 00077 00078 virtual Datum 00079 globalMin (const Datum& inDatum) 00080 { 00081 Datum input (inDatum); 00082 Datum output; 00083 00084 int count = 1; 00085 const int err = MPI_Allreduce (&input, &output, count, 00086 mpiType_.get(), MPI_MIN, comm_); 00087 if (err != MPI_SUCCESS) 00088 throw std::runtime_error ("MPI_Allreduce (MPI_MIN) failed"); 00089 return output; 00090 } 00091 00092 virtual Datum 00093 globalMax (const Datum& inDatum) 00094 { 00095 Datum input (inDatum); 00096 Datum output; 00097 00098 int count = 1; 00099 const int err = MPI_Allreduce (&input, &output, count, 00100 mpiType_.get(), MPI_MAX, comm_); 00101 if (err != MPI_SUCCESS) 00102 throw std::runtime_error ("MPI_Allreduce (MPI_MAX) failed"); 00103 return output; 00104 } 00105 00106 virtual void 00107 globalVectorSum (const Datum inData[], 00108 Datum outData[], 00109 const int count) 00110 { 00111 const int err = 00112 MPI_Allreduce (const_cast< Datum* const > (inData), outData, 00113 count, mpiType_.get(), MPI_SUM, comm_); 00114 if (err != MPI_SUCCESS) 00115 throw std::runtime_error ("MPI_Allreduce failed"); 00116 } 00117 00118 virtual void 00119 broadcast (Datum data[], 00120 const int count, 00121 const int root) 00122 { 00123 const int err = MPI_Bcast (data, count, mpiType_.get(), root, comm_); 00124 if (err != MPI_SUCCESS) 00125 throw std::runtime_error ("MPI_Bcast failed"); 00126 } 00127 00128 virtual int 00129 size() const 00130 { 00131 int nprocs = 0; 00132 const int err = MPI_Comm_size (comm_, &nprocs); 00133 if (err != MPI_SUCCESS) 00134 throw std::runtime_error ("MPI_Comm_size failed"); 00135 else if (nprocs <= 0) 00136 throw std::runtime_error ("MPI_Comm_size returned # processors <= 0"); 00137 return nprocs; 00138 } 00139 00140 virtual int 00141 rank() const 00142 { 00143 int my_rank = 0; 00144 const int err = MPI_Comm_rank (comm_, &my_rank); 00145 if (err != MPI_SUCCESS) 00146 throw std::runtime_error ("MPI_Comm_rank failed"); 00147 return my_rank; 00148 } 00149 00150 virtual void 00151 barrier () const 00152 { 00153 const int err = MPI_Barrier (comm_); 00154 if (err != MPI_SUCCESS) 00155 throw std::runtime_error ("MPI_Barrier failed"); 00156 } 00157 00158 private: 00166 mutable MPI_Comm comm_; 00167 00168 MpiDatatype< Datum > mpiType_; 00169 }; 00170 } // namespace MPI 00171 } // namespace TSQR 00172 00173 #endif // __TSQR_MpiMessenger_hpp
1.7.4