|
Anasazi Version of the Day
|
00001 #ifndef __TSQR_DistTsqrRB_hpp 00002 #define __TSQR_DistTsqrRB_hpp 00003 00004 #include <Tsqr_ApplyType.hpp> 00005 #include <Tsqr_Combine.hpp> 00006 #include <Tsqr_Matrix.hpp> 00007 #include <Tsqr_ScalarTraits.hpp> 00008 #include <Tsqr_StatTimeMonitor.hpp> 00009 00010 #include <algorithm> 00011 #include <sstream> 00012 #include <stdexcept> 00013 #include <utility> 00014 #include <vector> 00015 00018 00019 namespace TSQR { 00020 00028 template< class LocalOrdinal, class Scalar > 00029 class DistTsqrRB { 00030 public: 00031 typedef LocalOrdinal ordinal_type; 00032 typedef Scalar scalar_type; 00033 typedef typename ScalarTraits< scalar_type >::magnitude_type magnitude_type; 00034 typedef MatView< ordinal_type, scalar_type > matview_type; 00035 typedef Matrix< ordinal_type, scalar_type > matrix_type; 00036 typedef int rank_type; 00037 typedef Combine< ordinal_type, scalar_type > combine_type; 00038 00043 DistTsqrRB (const Teuchos::RCP< MessengerBase< scalar_type > >& messenger) : 00044 messenger_ (messenger), 00045 totalTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::factorExplicit() total time")), 00046 reduceCommTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::factorReduce() communication time")), 00047 reduceTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::factorReduce() total time")), 00048 bcastCommTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::explicitQBroadcast() communication time")), 00049 bcastTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::explicitQBroadcast() total time")) 00050 {} 00051 00055 void 00056 getStats (std::vector< TimeStats >& stats) const 00057 { 00058 const int numTimers = 5; 00059 stats.resize (std::max (stats.size(), static_cast<size_t>(numTimers))); 00060 00061 stats[0] = totalStats_; 00062 stats[1] = reduceCommStats_; 00063 stats[2] = reduceStats_; 00064 stats[3] = bcastCommStats_; 00065 stats[4] = bcastStats_; 00066 } 00067 00071 void 00072 getStatsLabels (std::vector< std::string >& labels) const 00073 { 00074 const int numTimers = 5; 00075 labels.resize (std::max (labels.size(), static_cast<size_t>(numTimers))); 00076 00077 labels[0] = totalTime_->name(); 00078 labels[1] = reduceCommTime_->name(); 00079 labels[2] = reduceTime_->name(); 00080 labels[3] = bcastCommTime_->name(); 00081 labels[4] = bcastTime_->name(); 00082 } 00083 00086 bool QR_produces_R_factor_with_nonnegative_diagonal () const { 00087 return combine_type::QR_produces_R_factor_with_nonnegative_diagonal(); 00088 } 00089 00106 void 00107 factorExplicit (matview_type R_mine, matview_type Q_mine) 00108 { 00109 StatTimeMonitor totalMonitor (*totalTime_, totalStats_); 00110 00111 // Dimension sanity checks. R_mine should have at least as many 00112 // rows as columns (since we will be working on the upper 00113 // triangle). Q_mine should have the same number of rows as 00114 // R_mine has columns, but Q_mine may have any number of 00115 // columns. (It depends on how many columns of the explicit Q 00116 // factor we want to compute.) 00117 if (R_mine.nrows() < R_mine.ncols()) 00118 { 00119 std::ostringstream os; 00120 os << "R factor input has fewer rows (" << R_mine.nrows() 00121 << ") than columns (" << R_mine.ncols() << ")"; 00122 // This is a logic error because TSQR users should not be 00123 // calling this method directly. 00124 throw std::logic_error (os.str()); 00125 } 00126 else if (Q_mine.nrows() != R_mine.ncols()) 00127 { 00128 std::ostringstream os; 00129 os << "Q factor input must have the same number of rows as the R " 00130 "factor input has columns. Q has " << Q_mine.nrows() 00131 << " rows, but R has " << R_mine.ncols() << " columns."; 00132 // This is a logic error because TSQR users should not be 00133 // calling this method directly. 00134 throw std::logic_error (os.str()); 00135 } 00136 00137 // The factorization is a recursion over processors [P_first, P_last]. 00138 const rank_type P_mine = messenger_->rank(); 00139 const rank_type P_first = 0; 00140 const rank_type P_last = messenger_->size() - 1; 00141 00142 // Intermediate Q factors are stored implicitly. QFactors[k] is 00143 // an upper triangular matrix of Householder reflectors, and 00144 // tauArrays[k] contains its corresponding scaling factors (TAU, 00145 // in LAPACK notation). These two arrays will be filled in by 00146 // factorReduce(). Different MPI processes will have different 00147 // numbers of elements in these arrays. In fact, on some 00148 // processes these arrays may be empty on output. This is a 00149 // feature, not a bug! 00150 // 00151 // Even though QFactors and tauArrays have the same type has the 00152 // first resp. second elements of DistTsqr::FactorOutput, they 00153 // are not compatible with the output of DistTsqr::factor() and 00154 // cannot be used as the input to DistTsqr::apply() or 00155 // DistTsqr::explicit_Q(). This is because factor() computes a 00156 // general factorization suitable for applying Q (or Q^T or Q^*) 00157 // to any compatible matrix, whereas factorExplicit() computes a 00158 // factorization specifically for the purpose of forming the 00159 // explicit Q factor. The latter lets us use a broadcast to 00160 // compute Q, rather than a more message-intensive all-to-all 00161 // (butterfly). 00162 std::vector< matrix_type > QFactors; 00163 std::vector< std::vector< scalar_type > > tauArrays; 00164 00165 { 00166 StatTimeMonitor reduceMonitor (*reduceTime_, reduceStats_); 00167 factorReduce (R_mine, P_mine, P_first, P_last, QFactors, tauArrays); 00168 } 00169 00170 if (QFactors.size() != tauArrays.size()) 00171 { 00172 std::ostringstream os; 00173 os << "QFactors and tauArrays should have the same number of element" 00174 "s after factorReduce() returns, but they do not. QFactors has " 00175 << QFactors.size() << " elements, but tauArrays has " 00176 << tauArrays.size() << " elements."; 00177 throw std::logic_error (os.str()); 00178 } 00179 00180 Q_mine.fill (scalar_type (0)); 00181 if (messenger_->rank() == 0) 00182 { 00183 for (ordinal_type j = 0; j < Q_mine.ncols(); ++j) 00184 Q_mine(j, j) = scalar_type (1); 00185 } 00186 // Scratch space for computing results to send to other processors. 00187 matrix_type Q_other (Q_mine.nrows(), Q_mine.ncols(), scalar_type (0)); 00188 const rank_type numSteps = QFactors.size() - 1; 00189 00190 { 00191 StatTimeMonitor bcastMonitor (*bcastTime_, bcastStats_); 00192 explicitQBroadcast (R_mine, Q_mine, Q_other.view(), 00193 P_mine, P_first, P_last, 00194 numSteps, QFactors, tauArrays); 00195 } 00196 } 00197 00198 private: 00199 00200 void 00201 factorReduce (matview_type R_mine, 00202 const rank_type P_mine, 00203 const rank_type P_first, 00204 const rank_type P_last, 00205 std::vector< matrix_type >& QFactors, 00206 std::vector< std::vector< scalar_type > >& tauArrays) 00207 { 00208 if (P_last < P_first) 00209 { 00210 std::ostringstream os; 00211 os << "Programming error in factorReduce() recursion: interval " 00212 "[P_first, P_last] is invalid: P_first = " << P_first 00213 << ", P_last = " << P_last << "."; 00214 throw std::logic_error (os.str()); 00215 } 00216 else if (P_mine < P_first || P_mine > P_last) 00217 { 00218 std::ostringstream os; 00219 os << "Programming error in factorReduce() recursion: P_mine (= " 00220 << P_mine << ") is not in current process rank interval " 00221 << "[P_first = " << P_first << ", P_last = " << P_last << "]"; 00222 throw std::logic_error (os.str()); 00223 } 00224 else if (P_last == P_first) 00225 return; // skip singleton intervals (see explanation below) 00226 else 00227 { 00228 // Recurse on two intervals: [P_first, P_mid-1] and [P_mid, 00229 // P_last]. For example, if [P_first, P_last] = [0, 9], 00230 // P_mid = floor( (0+9+1)/2 ) = 5 and the intervals are 00231 // [0,4] and [5,9]. 00232 // 00233 // If [P_first, P_last] = [4,6], P_mid = floor( (4+6+1)/2 ) 00234 // = 5 and the intervals are [4,4] (a singleton) and [5,6]. 00235 // The latter case shows that singleton intervals may arise. 00236 // We treat them as a base case in the recursion. Process 4 00237 // won't be skipped completely, though; it will get combined 00238 // with the result from [5,6]. 00239 00240 // Adding 1 and doing integer division works like "ceiling." 00241 const rank_type P_mid = (P_first + P_last + 1) / 2; 00242 00243 if (P_mine < P_mid) // Interval [P_first, P_mid-1] 00244 factorReduce (R_mine, P_mine, P_first, P_mid - 1, 00245 QFactors, tauArrays); 00246 else // Interval [P_mid, P_last] 00247 factorReduce (R_mine, P_mine, P_mid, P_last, 00248 QFactors, tauArrays); 00249 00250 // This only does anything if P_mine is either P_first or P_mid. 00251 if (P_mine == P_first) 00252 { 00253 const ordinal_type numCols = R_mine.ncols(); 00254 matrix_type R_other (numCols, numCols); 00255 recv_R (R_other, P_mid); 00256 00257 std::vector< scalar_type > tau (numCols); 00258 // Don't shrink the workspace array; doing so may 00259 // require expensive reallocation every time we send / 00260 // receive data. 00261 resizeWork (numCols); 00262 combine_.factor_pair (numCols, R_mine.get(), R_mine.lda(), 00263 R_other.get(), R_other.lda(), 00264 &tau[0], &work_[0]); 00265 QFactors.push_back (R_other); 00266 tauArrays.push_back (tau); 00267 } 00268 else if (P_mine == P_mid) 00269 send_R (R_mine, P_first); 00270 } 00271 } 00272 00273 void 00274 explicitQBroadcast (matview_type R_mine, 00275 matview_type Q_mine, 00276 matview_type Q_other, // workspace 00277 const rank_type P_mine, 00278 const rank_type P_first, 00279 const rank_type P_last, 00280 const rank_type curpos, 00281 std::vector< matrix_type >& QFactors, 00282 std::vector< std::vector< scalar_type > >& tauArrays) 00283 { 00284 if (P_last < P_first) 00285 { 00286 std::ostringstream os; 00287 os << "Programming error in explicitQBroadcast() recursion: interval" 00288 " [P_first, P_last] is invalid: P_first = " << P_first 00289 << ", P_last = " << P_last << "."; 00290 throw std::logic_error (os.str()); 00291 } 00292 else if (P_mine < P_first || P_mine > P_last) 00293 { 00294 std::ostringstream os; 00295 os << "Programming error in explicitQBroadcast() recursion: P_mine " 00296 "(= " << P_mine << ") is not in current process rank interval " 00297 << "[P_first = " << P_first << ", P_last = " << P_last << "]"; 00298 throw std::logic_error (os.str()); 00299 } 00300 else if (P_last == P_first) 00301 return; // skip singleton intervals 00302 else 00303 { 00304 // Adding 1 and integer division works like "ceiling." 00305 const rank_type P_mid = (P_first + P_last + 1) / 2; 00306 rank_type newpos = curpos; 00307 if (P_mine == P_first) 00308 { 00309 if (curpos < 0) 00310 { 00311 std::ostringstream os; 00312 os << "Programming error: On the current P_first (= " 00313 << P_first << ") proc: curpos (= " << curpos << ") < 0"; 00314 throw std::logic_error (os.str()); 00315 } 00316 // Q_impl, tau: implicitly stored local Q factor. 00317 matrix_type& Q_impl = QFactors[curpos]; 00318 std::vector< scalar_type >& tau = tauArrays[curpos]; 00319 00320 // Apply implicitly stored local Q factor to 00321 // [Q_mine; 00322 // Q_other] 00323 // where Q_other = zeros(Q_mine.nrows(), Q_mine.ncols()). 00324 // Overwrite both Q_mine and Q_other with the result. 00325 Q_other.fill (scalar_type (0)); 00326 combine_.apply_pair (ApplyType::NoTranspose, 00327 Q_mine.ncols(), Q_impl.ncols(), 00328 Q_impl.get(), Q_impl.lda(), &tau[0], 00329 Q_mine.get(), Q_mine.lda(), 00330 Q_other.get(), Q_other.lda(), &work_[0]); 00331 // Send the resulting Q_other, and the final R factor, to P_mid. 00332 send_Q_R (Q_other, R_mine, P_mid); 00333 newpos = curpos - 1; 00334 } 00335 else if (P_mine == P_mid) 00336 // P_first computed my explicit Q factor component. 00337 // Receive it, and the final R factor, from P_first. 00338 recv_Q_R (Q_mine, R_mine, P_first); 00339 00340 if (P_mine < P_mid) // Interval [P_first, P_mid-1] 00341 explicitQBroadcast (R_mine, Q_mine, Q_other, 00342 P_mine, P_first, P_mid - 1, 00343 newpos, QFactors, tauArrays); 00344 else // Interval [P_mid, P_last] 00345 explicitQBroadcast (R_mine, Q_mine, Q_other, 00346 P_mine, P_mid, P_last, 00347 newpos, QFactors, tauArrays); 00348 } 00349 } 00350 00351 template< class ConstMatrixType1, class ConstMatrixType2 > 00352 void 00353 send_Q_R (const ConstMatrixType1& Q, 00354 const ConstMatrixType2& R, 00355 const rank_type destProc) 00356 { 00357 StatTimeMonitor bcastCommMonitor (*bcastCommTime_, bcastCommStats_); 00358 00359 const ordinal_type R_numCols = R.ncols(); 00360 const ordinal_type Q_size = Q.nrows() * Q.ncols(); 00361 const ordinal_type R_size = (R_numCols * (R_numCols + 1)) / 2; 00362 const ordinal_type numElts = Q_size + R_size; 00363 00364 // Don't shrink the workspace array; doing so would still be 00365 // correct, but may require reallocation of data when it needs 00366 // to grow again. 00367 resizeWork (numElts); 00368 00369 // Pack the Q data into the workspace array. 00370 matview_type Q_contig (Q.nrows(), Q.ncols(), &work_[0], Q.nrows()); 00371 Q_contig.copy (Q); 00372 // Pack the R data into the workspace array. 00373 pack_R (R, &work_[Q_size]); 00374 messenger_->send (&work_[0], numElts, destProc, 0); 00375 } 00376 00377 template< class MatrixType1, class MatrixType2 > 00378 void 00379 recv_Q_R (MatrixType1& Q, 00380 MatrixType2& R, 00381 const rank_type srcProc) 00382 { 00383 StatTimeMonitor bcastCommMonitor (*bcastCommTime_, bcastCommStats_); 00384 00385 const ordinal_type R_numCols = R.ncols(); 00386 const ordinal_type Q_size = Q.nrows() * Q.ncols(); 00387 const ordinal_type R_size = (R_numCols * (R_numCols + 1)) / 2; 00388 const ordinal_type numElts = Q_size + R_size; 00389 00390 // Don't shrink the workspace array; doing so would still be 00391 // correct, but may require reallocation of data when it needs 00392 // to grow again. 00393 resizeWork (numElts); 00394 00395 messenger_->recv (&work_[0], numElts, srcProc, 0); 00396 00397 // Unpack the C data from the workspace array. 00398 Q.copy (matview_type (Q.nrows(), Q.ncols(), &work_[0], Q.nrows())); 00399 // Unpack the R data from the workspace array. 00400 unpack_R (R, &work_[Q_size]); 00401 } 00402 00403 template< class ConstMatrixType > 00404 void 00405 send_R (const ConstMatrixType& R, const rank_type destProc) 00406 { 00407 StatTimeMonitor reduceCommMonitor (*reduceCommTime_, reduceCommStats_); 00408 00409 const ordinal_type numCols = R.ncols(); 00410 const ordinal_type numElts = (numCols * (numCols+1)) / 2; 00411 00412 // Don't shrink the workspace array; doing so would still be 00413 // correct, but may require reallocation of data when it needs 00414 // to grow again. 00415 resizeWork (numElts); 00416 // Pack the R data into the workspace array. 00417 pack_R (R, &work_[0]); 00418 messenger_->send (&work_[0], numElts, destProc, 0); 00419 } 00420 00421 template< class MatrixType > 00422 void 00423 recv_R (MatrixType& R, const rank_type srcProc) 00424 { 00425 StatTimeMonitor reduceCommMonitor (*reduceCommTime_, reduceCommStats_); 00426 00427 const ordinal_type numCols = R.ncols(); 00428 const ordinal_type numElts = (numCols * (numCols+1)) / 2; 00429 00430 // Don't shrink the workspace array; doing so would still be 00431 // correct, but may require reallocation of data when it needs 00432 // to grow again. 00433 resizeWork (numElts); 00434 messenger_->recv (&work_[0], numElts, srcProc, 0); 00435 // Unpack the R data from the workspace array. 00436 unpack_R (R, &work_[0]); 00437 } 00438 00439 template< class MatrixType > 00440 static void 00441 unpack_R (MatrixType& R, const scalar_type buf[]) 00442 { 00443 ordinal_type curpos = 0; 00444 for (ordinal_type j = 0; j < R.ncols(); ++j) 00445 { 00446 scalar_type* const R_j = &R(0, j); 00447 for (ordinal_type i = 0; i <= j; ++i) 00448 R_j[i] = buf[curpos++]; 00449 } 00450 } 00451 00452 template< class ConstMatrixType > 00453 static void 00454 pack_R (const ConstMatrixType& R, scalar_type buf[]) 00455 { 00456 ordinal_type curpos = 0; 00457 for (ordinal_type j = 0; j < R.ncols(); ++j) 00458 { 00459 const scalar_type* const R_j = &R(0, j); 00460 for (ordinal_type i = 0; i <= j; ++i) 00461 buf[curpos++] = R_j[i]; 00462 } 00463 } 00464 00465 void 00466 resizeWork (const ordinal_type numElts) 00467 { 00468 typedef typename std::vector< scalar_type >::size_type vec_size_type; 00469 work_.resize (std::max (work_.size(), static_cast< vec_size_type >(numElts))); 00470 } 00471 00472 private: 00473 combine_type combine_; 00474 Teuchos::RCP< MessengerBase< scalar_type > > messenger_; 00475 std::vector< scalar_type > work_; 00476 00477 // Timers for various phases of the factorization. Time is 00478 // cumulative over all calls of factorExplicit(). 00479 Teuchos::RCP< Teuchos::Time > totalTime_; 00480 Teuchos::RCP< Teuchos::Time > reduceCommTime_; 00481 Teuchos::RCP< Teuchos::Time > reduceTime_; 00482 Teuchos::RCP< Teuchos::Time > bcastCommTime_; 00483 Teuchos::RCP< Teuchos::Time > bcastTime_; 00484 00485 TimeStats totalStats_, reduceCommStats_, reduceStats_, bcastCommStats_, bcastStats_; 00486 }; 00487 00488 } // namespace TSQR 00489 00490 #endif // __TSQR_DistTsqrRB_hpp
1.7.4