AFEPack
MPI.h
浏览该文件的文档。
00001 
00046 #ifndef __MPI_h__
00047 #define __MPI_h__
00048 
00049 #include <string>
00050 #include <vector>
00051 #include <list>
00052 #include <map>
00053 #include <mpi.h>
00054 
00055 #include <AFEPack/DerefIterator.h>
00056 #include <AFEPack/BinaryBuffer.h>
00057 
00058 namespace MPI {
00059 
00063   int get_random_tag(MPI_Comm comm);
00064 
00073   template <class DataIterator, class TargetIterator>
00074     void sendrecv_data(MPI_Comm comm, 
00075                        int n,
00076                        DataIterator start_send_data,
00077                        DataIterator start_recv_data,
00078                        TargetIterator start_target) {
00079     int tag = get_random_tag(comm);
00080 
00081     MPI_Request request[2*n];
00082     MPI_Status status[2*n];
00083 
00085     int send_data_size[n], recv_data_size[n];
00086     DataIterator the_send_data = start_send_data;
00087     for (int i = 0;i < n;++ i, ++ the_send_data) {
00088       send_data_size[i] = the_send_data->size();
00089     }
00090 
00091     TargetIterator the_target = start_target;
00092     for (int i = 0;i < n;++ i, ++ the_target) {
00093       MPI_Isend(&send_data_size[i], 1, MPI_INT,
00094                 *the_target, tag, comm, &request[i]);
00095       MPI_Irecv(&recv_data_size[i], 1, MPI_INT,
00096                 *the_target, tag, comm, &request[i + n]);
00097     }
00098     MPI_Waitall(2*n, request, status);
00099 
00101     DataIterator the_recv_data = start_recv_data;
00102     for (int i = 0;i < n;++ i, ++ the_recv_data) {
00103       the_recv_data->resize(recv_data_size[i]);
00104     }
00105 
00107     int n_request = 0;
00108     the_target = start_target;
00109     the_send_data = start_send_data;
00110     the_recv_data = start_recv_data;
00111     for (int i = 0;i < n;++ i) {
00112       if (send_data_size[i] > 0) {
00113         MPI_Isend(the_send_data->start_address(), send_data_size[i], MPI_CHAR,
00114                   *the_target, tag, comm, &request[n_request ++]);
00115       }
00116       if (recv_data_size[i] > 0) {
00117         MPI_Irecv(the_recv_data->start_address(), recv_data_size[i], MPI_CHAR,
00118                   *the_target, tag, comm, &request[n_request ++]);
00119       }
00120       ++ the_target, ++ the_send_data, ++ the_recv_data;
00121     }
00122     MPI_Waitall(n_request, request, status);
00123   }
00124 
/**
 * A pointer value belonging to another process, plus an integer type tag.
 *
 * The address is only meaningful in the process it came from; the code in
 * this header stores, compares and transmits it but never dereferences it.
 */
template <class T>
struct Remote_pointer {
  int type; ///< user-defined tag (0 by default), examined by Shared_type_filter
  T * ptr;  ///< address of the object in the remote process

  Remote_pointer() : type(0), ptr(NULL) {}
  Remote_pointer(int _type, T * _ptr) :
    type(_type), ptr(_ptr) {}
  Remote_pointer(const Remote_pointer<T>& rp) :
    type(rp.type), ptr(rp.ptr) {}

  /// Memberwise assignment. Must return *this: falling off the end of a
  /// function that returns a reference is undefined behavior.
  Remote_pointer<T>& operator=(const Remote_pointer<T>& rp) {
    type = rp.type;
    ptr = rp.ptr;
    return *this;
  }
};
00142 
/// Predicates over the integer type tag of a Remote_pointer. Transmit_map
/// calls one of these on each clone's tag to decide whether it is transmitted.
namespace Shared_type_filter {

/// Accepts every tag.
struct all {
  bool operator()(int /*type*/) const { return true; }
};

/// Accepts tags in the half-open interval [D0, D1).
template <int D0, int D1>
struct between {
  bool operator()(int type) const { return D0 <= type && type < D1; }
};

/// Accepts exactly the tag D.
template <int D>
struct only {
  bool operator()(int type) const { return D == type; }
};

/// Accepts every tag except D.
template <int D>
struct except {
  bool operator()(int type) const { return !(D == type); }
};

/// Accepts tags strictly greater than D.
template <int D>
struct greater_than {
  bool operator()(int type) const { return D < type; }
};

/// Accepts tags strictly less than D.
template <int D>
struct less_than {
  bool operator()(int type) const { return D > type; }
};

/// Logical complement of another filter; the wrapped filter is stored by value.
template <class FILTER>
struct negate {
  FILTER filter;

  negate() {}
  negate(const FILTER& _filter) : filter(_filter) {}

  bool operator()(int type) const { return !filter(type); }
};

} // namespace Shared_type_filter
00192 
00201   template <class T>
00202     struct Shared_object : public std::multimap<int,Remote_pointer<T> > {
00203     typedef Remote_pointer<T> pointer_t;
00204     typedef std::pair<int,pointer_t> pair_t;
00205     typedef std::multimap<int,pointer_t> _Base;
00206 
00207     T * _ptr; 
00208 
00209     Shared_object() {}
00210     Shared_object(T& t) : _ptr(&t) {}
00211     void add_clone(int rank, T* ptr) {
00212       this->add_clone(rank, pointer_t(0, ptr));
00213     }
00214     void add_clone(int rank, int type, T* ptr) {
00215       this->add_clone(rank, pointer_t(type, ptr));
00216     }
00217     void add_clone(int rank, const pointer_t& ptr) {
00218       if (! is_duplicate_entry(rank, ptr)) {
00219         this->insert(pair_t(rank, ptr));
00220       }
00221     }
00222 
00223     T *& local_pointer() { return _ptr; }
00224     T * local_pointer() const { return _ptr; }
00225     T& local_object() const { return *_ptr; }
00226 
00230     bool is_duplicate_entry(int rank, 
00231                             const pointer_t& ptr) const {
00232       typedef typename _Base::const_iterator it_t;
00233       std::pair<it_t,it_t> range = _Base::equal_range(rank);
00234       it_t the_ptr = range.first, end_ptr = range.second;
00235       for (;the_ptr != end_ptr;++ the_ptr) {
00236         if (the_ptr->second.ptr == ptr.ptr) {
00237           return true;
00238         }
00239       }
00240       return false;
00241     }
00242 
00244 
00253     int primary_rank(int rank) const {
00254       return std::min(_Base::begin()->first, rank);
00255     }
00256 
00262     bool is_on_primary_rank(int rank) const {
00263       return (_Base::begin()->first >= rank);
00264     }
00266 
00275     bool is_primary_object(int rank) const {
00276       int first_rank = _Base::begin()->first;
00277       bool result = true;
00278       if (first_rank < rank) {
00279         result = false; 
00280       } else if (first_rank == rank) { 
00284         typedef typename _Base::const_iterator it_t;
00285         it_t the_ptr = _Base::begin();
00286         it_t end_ptr = _Base::upper_bound(rank);
00287         for (;the_ptr != end_ptr;++ the_ptr) {
00288           assert (the_ptr->second.ptr != _ptr);
00289           if (the_ptr->second.ptr < _ptr) {
00290             result = false;
00291             break;
00292           }
00293         }
00294       } else { 
00295         result = true; 
00296       }
00297       return result;
00298     }
00299   };
00300 
00305   template <class T, 
00306     template <class C, typename ALLOC = std::allocator<C> > class CNT = std::list>
00307     struct Shared_list : public CNT<Shared_object<T> > {};
00308 
00313   template <class T,
00314     template <class C, typename ALLOC = std::allocator<C> > class CNT = std::list>
00315     struct Shared_ptr_list : public CNT<Shared_object<T> *> {
00316     typedef CNT<Shared_object<T> *> base_t;
00317     typedef _Deref_iterator<typename base_t::iterator, Shared_object<T> > iterator;
00318     typedef _Deref_iterator<typename base_t::const_iterator, const Shared_object<T> > const_iterator;
00319     iterator begin() { return base_t::begin(); }
00320     iterator end() { return base_t::end(); }
00321     const_iterator begin() const { return base_t::begin(); }
00322     const_iterator end() const { return base_t::end(); }
00323     typename base_t::iterator begin_ptr() { return base_t::begin(); }
00324     typename base_t::iterator end_ptr() { return base_t::end(); }
00325     typename base_t::const_iterator begin_ptr() const { return base_t::begin(); }
00326     typename base_t::const_iterator end_ptr() const { return base_t::end(); }
00327   };
00328 
00335   template <class T, class SHARED_TYPE_FILTER=Shared_type_filter::all>
00336     struct Transmit_map : 
00337     public std::map<int, std::pair<int, std::list<std::pair<T*, T*> > > > {
00338     typedef std::list<std::pair<T*, T*> > value_t;
00339     typedef std::pair<int, value_t> pair_t;
00340     typedef std::map<int, pair_t> _Base;
00341     typedef SHARED_TYPE_FILTER type_filter_t;
00342 
00343     type_filter_t type_filter;
00344 
00349     template <class CONTAINER>
00350       void build(const CONTAINER& shlist) {
00351       _Base::clear();
00352 
00353       typename CONTAINER::const_iterator 
00354         the_obj = shlist.begin(),
00355         end_obj = shlist.end();
00356       for (;the_obj != end_obj;++ the_obj) {
00357         this->add_object(*the_obj);
00358       }
00359     }
00360 
00361     template <class CONTAINER>
00362       void build(const CONTAINER& shlist,
00363                  bool (*filter)(T *)) {
00364       _Base::clear();
00365 
00366       typename CONTAINER::const_iterator 
00367         the_obj = shlist.begin(),
00368         end_obj = shlist.end();
00369       for (;the_obj != end_obj;++ the_obj) {
00370         this->add_object(*the_obj, 
00371                          (*filter)(the_obj->local_pointer()));
00372       }
00373     }
00374 
00375     template <class CONTAINER, class DATA_PACKER>
00376       void build(const CONTAINER& shlist,
00377                  DATA_PACKER& data_packer,
00378                  bool (DATA_PACKER::*filter)(T *)) {
00379       _Base::clear();
00380 
00381       typename CONTAINER::const_iterator 
00382         the_obj = shlist.begin(),
00383         end_obj = shlist.end();
00384       for (;the_obj != end_obj;++ the_obj) {
00385         this->add_object(*the_obj, 
00386                          (data_packer.*filter)(the_obj->local_pointer()));
00387       }
00388     }
00389 
00390     template <class CONTAINER, class DATA_PACKER>
00391       void build(const CONTAINER& shlist,
00392                  const DATA_PACKER& data_packer,
00393                  bool (DATA_PACKER::*filter)(T *) const) {
00394       _Base::clear();
00395 
00396       typename CONTAINER::const_iterator 
00397         the_obj = shlist.begin(),
00398         end_obj = shlist.end();
00399       for (;the_obj != end_obj;++ the_obj) {
00400         this->add_object(*the_obj, 
00401                          (data_packer.*filter)(the_obj->local_pointer()));
00402       }
00403     }
00404 
00409     template <class ITERATOR>
00410       void build(ITERATOR& begin, ITERATOR& end) {
00411       _Base::clear();
00412 
00413       ITERATOR the_obj(begin);
00414       for (;the_obj != end;++ the_obj) {
00415         this->add_object(*the_obj);
00416       }
00417     }
00418 
00422     void add_object(const Shared_object<T>& obj, 
00423                     bool is_add_entry = true) {
00424       typename Shared_object<T>::const_iterator
00425         the_ptr = obj.begin(),
00426         end_ptr = obj.end();
00427       if (! is_add_entry) {
00431         for (;the_ptr != end_ptr;++ the_ptr) {
00432           int rank = the_ptr->first;
00433           if (this->find(rank) == this->end()) {
00434             (*this)[rank] = pair_t(0, value_t());
00435           }
00436         }
00437       } else {
00438         T * obj_ptr = obj.local_pointer();
00439         for (;the_ptr != end_ptr;++ the_ptr) {
00440           const int& rank = the_ptr->first;
00441           if (this->find(rank) == this->end()) {
00442             (*this)[rank] = pair_t(0, value_t());
00443           }
00444 
00445           if (type_filter(the_ptr->second.type)) {
00446             pair_t& pair = (*this)[rank];
00447             pair.first += 1;
00448             pair.second.push_back(std::pair<T*,T*>(obj_ptr, 
00449                                                    the_ptr->second.ptr));
00450           }
00451         }
00452       }
00453     }
00454   };
00455 
00476   template <class T, class DATA_PACKER, class SHARED_TYPE_FILTER>
00477     void sync_data(MPI_Comm comm,
00478                    Transmit_map<T,SHARED_TYPE_FILTER>& map,
00479                    DATA_PACKER& data_packer,
00480                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00481                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&)) {
00482     typedef Transmit_map<T,SHARED_TYPE_FILTER> map_t;
00483     typedef typename map_t::value_t value_t;
00484 
00485     std::list<int> target_list;
00486     std::list<BinaryBuffer<> > data_buffer_in, data_buffer_out;
00487 
00488     int n = 0;
00489     typename map_t::iterator
00490       the_pair = map.begin(),
00491       end_pair = map.end();
00492     for (;the_pair != end_pair;++ the_pair, ++ n) {
00493       int rank = the_pair->first;
00494       target_list.push_back(rank);
00495 
00496       data_buffer_in.push_back(BinaryBuffer<>());
00497       data_buffer_out.push_back(BinaryBuffer<>());
00498 
00499       AFEPack::ostream<> os(data_buffer_out.back());
00500       int n_item = the_pair->second.first;
00501       if (n_item == 0) continue;
00502 
00503       value_t& lst = the_pair->second.second;
00504       os << n_item; 
00505       typename value_t::iterator
00506         the_ptr = lst.begin(), end_ptr = lst.end();
00507       for (;the_ptr != end_ptr;++ the_ptr) {
00508         T *& local_obj = the_ptr->first;
00509         T *& remote_obj = the_ptr->second;
00510         os << remote_obj;
00511         (data_packer.*pack)(local_obj, rank, os);
00512       }
00513     }
00514 
00515     sendrecv_data(comm, n, data_buffer_out.begin(), data_buffer_in.begin(),
00516                   target_list.begin());
00517 
00518     typename std::list<BinaryBuffer<> >::iterator 
00519       the_buf = data_buffer_in.begin();
00520     the_pair = map.begin();
00521     for (;the_pair != end_pair;++ the_pair, ++ the_buf) {
00522       if (the_buf->size() == 0) continue;
00523 
00524       int rank = the_pair->first;
00525       AFEPack::istream<> is(*the_buf);
00526       int n_item;
00527       T * local_obj;
00528       is >> n_item; 
00529       for (int i = 0;i < n_item;++ i) {
00530         is >> local_obj;
00531         (data_packer.*unpack)(local_obj, rank, is);
00532       }
00533     }
00534   }
00535 
00546   template <class T, class SHARED_LIST, class DATA_PACKER>
00547     void sync_data(MPI_Comm comm,
00548                    SHARED_LIST& shlist,
00549                    DATA_PACKER& data_packer,
00550                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00551                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&)) {
00552     Transmit_map<T> map;
00553     map.build(shlist);
00554     sync_data(comm, map, data_packer, pack, unpack);
00555   }
00556 
00557   template <class T, class SHARED_LIST, class DATA_PACKER, class SHARED_TYPE_FILTER>
00558     void sync_data(MPI_Comm comm,
00559                    SHARED_LIST& shlist,
00560                    DATA_PACKER& data_packer,
00561                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00562                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00563                    const SHARED_TYPE_FILTER& stf) {
00564     Transmit_map<T,SHARED_TYPE_FILTER> map;
00565     map.build(shlist);
00566     sync_data(comm, map, data_packer, pack, unpack);
00567   }
00568 
00572   template <class T, class SHARED_LIST, class DATA_PACKER>
00573     void sync_data(MPI_Comm comm,
00574                    SHARED_LIST& shlist,
00575                    DATA_PACKER& data_packer,
00576                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00577                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00578                    bool (DATA_PACKER::*filter)(T *)) {
00579     Transmit_map<T> map;
00580     map.build(shlist, data_packer, filter);
00581     sync_data(comm, map, data_packer, pack, unpack);
00582   }
00583 
00584   template <class T, class SHARED_LIST, class DATA_PACKER, class SHARED_TYPE_FILTER>
00585     void sync_data(MPI_Comm comm,
00586                    SHARED_LIST& shlist,
00587                    DATA_PACKER& data_packer,
00588                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00589                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00590                    bool (DATA_PACKER::*filter)(T *),
00591                    const SHARED_TYPE_FILTER& stf) {
00592     Transmit_map<T,SHARED_TYPE_FILTER> map;
00593     map.build(shlist, data_packer, filter);
00594     sync_data(comm, map, data_packer, pack, unpack);
00595   }
00596 
00600   template <class T, class SHARED_LIST, class DATA_PACKER>
00601     void sync_data(MPI_Comm comm,
00602                    SHARED_LIST& shlist,
00603                    DATA_PACKER& data_packer,
00604                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00605                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00606                    bool (DATA_PACKER::*filter)(T *) const) {
00607     Transmit_map<T> map;
00608     map.build(shlist, data_packer, filter);
00609     sync_data(comm, map, data_packer, pack, unpack);
00610   }
00611 
00612   template <class T, class SHARED_LIST, class DATA_PACKER, class SHARED_TYPE_FILTER>
00613     void sync_data(MPI_Comm comm,
00614                    SHARED_LIST& shlist,
00615                    DATA_PACKER& data_packer,
00616                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00617                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00618                    bool (DATA_PACKER::*filter)(T *) const,
00619                    const SHARED_TYPE_FILTER& stf) {
00620     Transmit_map<T,SHARED_TYPE_FILTER> map;
00621     map.build(shlist, data_packer, filter);
00622     sync_data(comm, map, data_packer, pack, unpack);
00623   }
00624 
00625 } // namespace MPI
00626 
00627 #endif // __MPI_h__
00628