|
Anasazi Version of the Day
|
00001 // @HEADER 00002 // *********************************************************************** 00003 // 00004 // Anasazi: Block Eigensolvers Package 00005 // Copyright (2010) Sandia Corporation 00006 // 00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive 00008 // license for use of this work by or on behalf of the U.S. Government. 00009 // 00010 // This library is free software; you can redistribute it and/or modify 00011 // it under the terms of the GNU Lesser General Public License as 00012 // published by the Free Software Foundation; either version 2.1 of the 00013 // License, or (at your option) any later version. 00014 // 00015 // This library is distributed in the hope that it will be useful, but 00016 // WITHOUT ANY WARRANTY; without even the implied warranty of 00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00018 // Lesser General Public License for more details. 00019 // 00020 // You should have received a copy of the GNU Lesser General Public 00021 // License along with this library; if not, write to the Free Software 00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 00023 // USA 00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 00025 // 00026 // *********************************************************************** 00027 // @HEADER 00028 00029 #ifndef __TSQR_Trilinos_TsqrAdaptor_hpp 00030 #define __TSQR_Trilinos_TsqrAdaptor_hpp 00031 00035 00036 #include "AnasaziConfigDefs.hpp" 00037 #include "Teuchos_SerialDenseMatrix.hpp" 00038 00039 #include "TsqrTypeAdaptor.hpp" 00040 #include "TsqrCommFactory.hpp" 00041 00042 #include "Tsqr_GlobalVerify.hpp" 00043 #include "Tsqr_ScalarTraits.hpp" 00044 00045 #include <stdexcept> 00046 #include <sstream> 00047 00050 00051 namespace TSQR { 00052 namespace Trilinos { 00053 00104 template< class S, class LO, class GO, class MV > 00105 class TsqrAdaptor { 00106 public: 00107 typedef S scalar_type; 00108 typedef LO local_ordinal_type; 00109 typedef GO global_ordinal_type; 00110 typedef MV multivector_type; 00111 00112 typedef typename TSQR::ScalarTraits< scalar_type >::magnitude_type magnitude_type; 00113 00114 typedef TsqrTypeAdaptor< S, LO, GO, MV > type_adaptor; 00115 typedef typename type_adaptor::factory_type factory_type; 00116 00117 typedef typename type_adaptor::node_tsqr_type node_tsqr_type; 00118 typedef typename type_adaptor::node_tsqr_ptr node_tsqr_ptr; 00119 00120 typedef typename type_adaptor::comm_type comm_type; 00121 typedef typename type_adaptor::comm_ptr comm_ptr; 00122 00123 typedef typename type_adaptor::dist_tsqr_type dist_tsqr_type; 00124 typedef typename type_adaptor::dist_tsqr_ptr dist_tsqr_ptr; 00125 00126 typedef typename type_adaptor::tsqr_type tsqr_type; 00127 typedef typename type_adaptor::tsqr_ptr tsqr_ptr; 00128 00129 typedef typename tsqr_type::FactorOutput factor_output_type; 00130 typedef Teuchos::SerialDenseMatrix< LO, S > dense_matrix_type; 00131 typedef Teuchos::RCP< MessengerBase< S > > scalar_messenger_ptr; 00132 typedef Teuchos::RCP< MessengerBase< LO > > ordinal_messenger_ptr; 00133 00134 virtual ~TsqrAdaptor() {} 00135 00172 virtual factor_output_type 00173 factor (multivector_type& A, 00174 dense_matrix_type& R, 00175 const bool contiguousCacheBlocks = false) 00176 { 00177 local_ordinal_type nrowsLocal, ncols, LDA; 00178 fetchDims (A, nrowsLocal, ncols, LDA); 00179 // This is guaranteed to be _correct_ for any Node type, but 00180 // won't necessary be efficient. The desired model is that 00181 // A_local requires no copying. 00182 Teuchos::ArrayRCP< scalar_type > A_local = fetchNonConstView (A); 00183 00184 // Reshape R if necessary. This operation zeros out all the 00185 // entries of R, which is what we want anyway. 00186 if (R.numRows() != ncols || R.numCols() != ncols) 00187 { 00188 if (0 != R.shape (ncols, ncols)) 00189 throw std::runtime_error ("Failed to reshape matrix R"); 00190 } 00191 return pTsqr_->factor (nrowsLocal, ncols, A_local.get(), LDA, 00192 R.values(), R.stride(), contiguousCacheBlocks); 00193 } 00194 00221 virtual void 00222 explicitQ (const multivector_type& Q_in, 00223 const factor_output_type& factorOutput, 00224 multivector_type& Q_out, 00225 const bool contiguousCacheBlocks = false) 00226 { 00227 using Teuchos::ArrayRCP; 00228 00229 local_ordinal_type nrowsLocal, ncols_in, LDQ_in; 00230 fetchDims (Q_in, nrowsLocal, ncols_in, LDQ_in); 00231 local_ordinal_type nrowsLocal_out, ncols_out, LDQ_out; 00232 fetchDims (Q_out, nrowsLocal_out, ncols_out, LDQ_out); 00233 00234 if (nrowsLocal_out != nrowsLocal) 00235 { 00236 std::ostringstream os; 00237 os << "TSQR explicit Q: input Q factor\'s node-local part has a di" 00238 "fferent number of rows (" << nrowsLocal << ") than output Q fac" 00239 "tor\'s node-local part (" << nrowsLocal_out << ")."; 00240 throw std::runtime_error (os.str()); 00241 } 00242 ArrayRCP< const scalar_type > pQin = fetchConstView (Q_in); 00243 ArrayRCP< scalar_type > pQout = fetchNonConstView (Q_out); 00244 pTsqr_->explicit_Q (nrowsLocal, 00245 ncols_in, pQin.get(), LDQ_in, 00246 factorOutput, 00247 ncols_out, pQout.get(), LDQ_out, 00248 contiguousCacheBlocks); 00249 } 00250 00275 local_ordinal_type 00276 revealRank (multivector_type& Q, 00277 dense_matrix_type& R, 00278 const magnitude_type relativeTolerance, 00279 const bool contiguousCacheBlocks = false) const 00280 { 00281 using Teuchos::ArrayRCP; 00282 00283 local_ordinal_type nrowsLocal, ncols, ldqLocal; 00284 fetchDims (Q, nrowsLocal, ncols, ldqLocal); 00285 00286 ArrayRCP< scalar_type > Q_ptr = fetchNonConstView (Q); 00287 return pTsqr_->reveal_rank (nrowsLocal, ncols, 00288 Q_ptr.get(), ldqLocal, 00289 R.values(), R.stride(), 00290 relativeTolerance, 00291 contiguousCacheBlocks); 00292 } 00293 00304 virtual void 00305 cacheBlock (const multivector_type& A_in, 00306 multivector_type& A_out) 00307 { 00308 using Teuchos::ArrayRCP; 00309 00310 local_ordinal_type nrowsLocal, ncols, LDA_in; 00311 fetchDims (A_in, nrowsLocal, ncols, LDA_in); 00312 local_ordinal_type nrowsLocal_out, ncols_out, LDA_out; 00313 fetchDims (A_out, nrowsLocal_out, ncols_out, LDA_out); 00314 00315 if (nrowsLocal_out != nrowsLocal) 00316 { 00317 std::ostringstream os; 00318 os << "TSQR cache block: the input matrix\'s node-local part has a" 00319 " different number of rows (" << nrowsLocal << ") than the outpu" 00320 "t matrix\'s node-local part (" << nrowsLocal_out << ")."; 00321 throw std::runtime_error (os.str()); 00322 } 00323 else if (ncols_out != ncols) 00324 { 00325 std::ostringstream os; 00326 os << "TSQR cache block: the input matrix\'s node-local part has a" 00327 " different number of columns (" << ncols << ") than the output " 00328 "matrix\'s node-local part (" << ncols_out << ")."; 00329 throw std::runtime_error (os.str()); 00330 } 00331 ArrayRCP< const scalar_type > pA_in = fetchConstView (A_in); 00332 ArrayRCP< scalar_type > pA_out = fetchNonConstView (A_out); 00333 pTsqr_->cache_block (nrowsLocal, ncols, pA_out.get(), 00334 pA_in.get(), LDA_in); 00335 } 00336 00342 virtual void 00343 unCacheBlock (const multivector_type& A_in, 00344 multivector_type& A_out) 00345 { 00346 using Teuchos::ArrayRCP; 00347 00348 local_ordinal_type nrowsLocal, ncols, LDA_in; 00349 fetchDims (A_in, nrowsLocal, ncols, LDA_in); 00350 local_ordinal_type nrowsLocal_out, ncols_out, LDA_out; 00351 fetchDims (A_out, nrowsLocal_out, ncols_out, LDA_out); 00352 00353 if (nrowsLocal_out != nrowsLocal) 00354 { 00355 std::ostringstream os; 00356 os << "TSQR un-cache-block: the input matrix\'s node-local part ha" 00357 "s a different number of rows (" << nrowsLocal << ") than the ou" 00358 "tput matrix\'s node-local part (" << nrowsLocal_out << ")."; 00359 throw std::runtime_error (os.str()); 00360 } 00361 else if (ncols_out != ncols) 00362 { 00363 std::ostringstream os; 00364 os << "TSQR cache block: the input matrix\'s node-local part has a" 00365 " different number of columns (" << ncols << ") than the output " 00366 "matrix\'s node-local part (" << ncols_out << ")."; 00367 throw std::runtime_error (os.str()); 00368 } 00369 ArrayRCP< const scalar_type > pA_in = fetchConstView (A_in); 00370 ArrayRCP< scalar_type > pA_out = fetchNonConstView (A_out); 00371 pTsqr_->un_cache_block (nrowsLocal, ncols, pA_out.get(), 00372 LDA_out, pA_in.get()); 00373 } 00374 00377 virtual std::pair< magnitude_type, magnitude_type > 00378 verify (const multivector_type& A, 00379 const multivector_type& Q, 00380 const Teuchos::SerialDenseMatrix< local_ordinal_type, scalar_type >& R) 00381 { 00382 using Teuchos::ArrayRCP; 00383 00384 local_ordinal_type nrowsLocal_A, ncols_A, LDA; 00385 local_ordinal_type nrowsLocal_Q, ncols_Q, LDQ; 00386 fetchDims (A, nrowsLocal_A, ncols_A, LDA); 00387 fetchDims (Q, nrowsLocal_Q, ncols_Q, LDQ); 00388 if (nrowsLocal_A != nrowsLocal_Q) 00389 throw std::runtime_error ("A and Q must have same number of rows"); 00390 else if (ncols_A != ncols_Q) 00391 throw std::runtime_error ("A and Q must have same number of columns"); 00392 else if (ncols_A != R.numCols()) 00393 throw std::runtime_error ("A and R must have same number of columns"); 00394 else if (R.numRows() < R.numCols()) 00395 throw std::runtime_error ("R must have no fewer rows than columns"); 00396 00397 // Const views suffice for verification 00398 ArrayRCP< const scalar_type > A_ptr = fetchConstView (A); 00399 ArrayRCP< const scalar_type > Q_ptr = fetchConstView (Q); 00400 return global_verify (nrowsLocal_A, ncols_A, A_ptr.get(), LDA, 00401 Q_ptr.get(), LDQ, R.values(), R.stride(), 00402 pScalarMessenger_.get()); 00403 } 00404 00405 protected: 00408 void 00409 init (const multivector_type& mv, 00410 const Teuchos::ParameterList& plist) 00411 { 00412 // This is done in a multivector type - dependent way. 00413 fetchMessengers (mv, pScalarMessenger_, pOrdinalMessenger_); 00414 00415 factory_type factory; 00416 // plist and pScalarMessenger_ are inputs. Construct *pTsqr_. 00417 factory.makeTsqr (plist, pScalarMessenger_, pTsqr_); 00418 } 00419 00420 private: 00436 virtual void 00437 fetchDims (const multivector_type& A, 00438 local_ordinal_type& nrowsLocal, 00439 local_ordinal_type& ncols, 00440 local_ordinal_type& LDA) const = 0; 00441 00449 virtual Teuchos::ArrayRCP< scalar_type > 00450 fetchNonConstView (multivector_type& A) const = 0; 00451 00459 virtual Teuchos::ArrayRCP< const scalar_type > 00460 fetchConstView (const multivector_type& A) const = 0; 00461 00464 virtual void 00465 fetchMessengers (const multivector_type& mv, 00466 scalar_messenger_ptr& pScalarMessenger, 00467 ordinal_messenger_ptr& pOrdinalMessenger) const = 0; 00468 00471 scalar_messenger_ptr pScalarMessenger_; 00472 00475 ordinal_messenger_ptr pOrdinalMessenger_; 00476 00479 tsqr_ptr pTsqr_; 00480 }; 00481 00482 } // namespace Trilinos 00483 } // namespace TSQR 00484 00485 #endif // __TSQR_Trilinos_TsqrAdaptor_hpp
1.7.4