TSFDenseSerialMatrix.cpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 /* ***********************************************************************
00003 // 
00004 //           TSFExtended: Trilinos Solver Framework Extended
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // **********************************************************************/
00027  /* @HEADER@ */
00028 
00029 #include "TSFDenseSerialMatrix.hpp"
00030 #include "TSFDenseSerialMatrixFactory.hpp"
00031 #include "TSFSerialVector.hpp"
00032 #include "TSFVectorSpaceDecl.hpp"  
00033 #include "TSFVectorDecl.hpp"
00034 #include "TSFLinearOperatorDecl.hpp"
00035 #include "Teuchos_BLAS.hpp"
00036 
00037 #ifndef HAVE_TEUCHOS_EXPLICIT_INSTANTIATION
00038 #include "TSFLinearOperatorImpl.hpp"
00039 #include "TSFVectorImpl.hpp"
00040 #endif
00041 
/* Prototypes for the Fortran LAPACK routines called directly from this file.
 * All arguments are passed by pointer, following the Fortran calling
 * convention used by the trailing-underscore symbols. */
extern "C"
{
/* DGESV: used by denseSolve() below. Overwrites both a and b. */
void dgesv_(
  int* n,
  int* nrhs,
  double* a,
  int* lda,
  int* ipiv,
  double* b,
  int* ldb,
  int* info);

/* DGESVD: used by denseSVD() below. Overwrites a. */
void dgesvd_(
  char* jobu,
  char* jobvt,
  int* m,
  int* n,
  double* a,
  int* lda,
  double* s,
  double* u,
  int* ldu,
  double* vt,
  int* ldvt,
  double* work,
  int* lwork,
  int* info);
}
00051 using std::max;
00052 using std::min;
00053 
00054 using namespace TSFExtended;
00055 using namespace Teuchos;
00056 using namespace Thyra;
00057 using std::setw;
00058 
00059 DenseSerialMatrix::DenseSerialMatrix(
00060   const RCP<const SerialVectorSpace>& domain,
00061   const RCP<const SerialVectorSpace>& range)
00062   : range_(range),
00063     domain_(domain),
00064     nRows_(range_->dim()),
00065     nCols_(domain_->dim()),
00066     data_(nRows_*nCols_)
00067 {}
00068 
00069 
00070 Teuchos::ETransp thyraTransToTeuchosTrans(const Thyra::EOpTransp M_trans)
00071 {
00072   switch(M_trans)
00073   {
00074     case Thyra::NOTRANS:
00075       return Teuchos::NO_TRANS;
00076     case Thyra::TRANS:
00077       return Teuchos::TRANS;
00078     case Thyra::CONJTRANS:
00079       return Teuchos::CONJ_TRANS;
00080     default:
00081       TEST_FOR_EXCEPT(true);
00082   }
00083   return Teuchos::NO_TRANS; // -Wall
00084 }
00085 
00086 void DenseSerialMatrix::applyOp(
00087   const Thyra::EOpTransp M_trans,
00088   const Vector<double>& in,
00089   Vector<double> out) const
00090 {
00091   const SerialVector* rvIn = SerialVector::getConcrete(in);
00092   SerialVector* rvOut = SerialVector::getConcrete(out);
00093 
00094   Teuchos::BLAS<OrdType, double> blas;
00095   int lda = numRows();
00096   Teuchos::ETransp trans = thyraTransToTeuchosTrans(M_trans);
00097   blas.GEMV(trans, numRows(), numCols(), 1.0, dataPtr(), 
00098     lda, rvIn->dataPtr(), 1, 1.0, rvOut->dataPtr(), 1);
00099 }
00100 
00101 void DenseSerialMatrix::addToRow(int globalRowIndex,
00102   int nElemsToInsert,
00103   const int* globalColumnIndices,
00104   const double* elementValues)
00105 {
00106   int r = globalRowIndex;
00107   for (int k=0; k<nElemsToInsert; k++)
00108   {
00109     int c = globalColumnIndices[k];
00110     double x = elementValues[k];
00111     data_[r + c*numRows()] = x;
00112   }
00113 }
00114 
00115 void DenseSerialMatrix::zero()
00116 {
00117   for (int i=0; i<data_.size(); i++) data_[i] = 0.0;
00118 }
00119 
00120 
00121 void DenseSerialMatrix::print(std::ostream& os) const
00122 {
00123   if (numCols() <= 4)
00124   {
00125     for (int i=0; i<numRows(); i++)
00126     {
00127       for (int j=0; j<numCols(); j++)
00128       {
00129         os << setw(16) << data_[i+numRows()*j];
00130       }
00131       os << std::endl;
00132     }
00133   }
00134   else
00135   {
00136     for (int i=0; i<numRows(); i++)
00137     {
00138       for (int j=0; j<numCols(); j++)
00139       {
00140         os << setw(6) << i << setw(6) << j << setw(20) << data_[i+numRows()*j]
00141            << std::endl;
00142       }
00143     }
00144   }
00145 }
00146 
00147 void DenseSerialMatrix::setRow(int row, const Array<double>& rowVals)
00148 {
00149   TEST_FOR_EXCEPT(rowVals.size() != numCols());
00150   TEST_FOR_EXCEPT(row < 0);
00151   TEST_FOR_EXCEPT(row >= numRows());
00152 
00153   for (int i=0; i<rowVals.size(); i++)
00154   {
00155     data_[row+numRows()*i] = rowVals[i];
00156   }
00157 }
00158 
00159 
00160 namespace TSFExtended
00161 {
00162 
00163 
00164 SolverState<double> denseSolve(const LinearOperator<double>& A,
00165   const Vector<double>& b,
00166   Vector<double>& x)
00167 {
00168   const DenseSerialMatrix* Aptr 
00169     = dynamic_cast<const DenseSerialMatrix*>(A.ptr().get());
00170   TEST_FOR_EXCEPT(Aptr==0);
00171   /* make a working copy, because dgesv will overwrite the matrix */
00172   DenseSerialMatrix tmp = *Aptr;
00173   /* Allocate a vector for the solution */
00174   x = b.copy();
00175   SerialVector* xptr 
00176     = dynamic_cast<SerialVector*>(x.ptr().get());
00177   
00178   int N = Aptr->numRows();
00179   int nRHS = 1;
00180   int LDA = N;
00181   Array<int> iPiv(N);
00182   int LDB = N;
00183   int info = 0;
00184   dgesv_(&N, &nRHS, tmp.dataPtr(), &LDA, &(iPiv[0]), xptr->dataPtr(),
00185     &LDB, &info);
00186 
00187   if (info == 0)
00188   {
00189     return SolverState<double>(SolveConverged, "solve OK",
00190       0, 0.0);
00191   }
00192   else 
00193   {
00194     return SolverState<double>(SolveCrashed, "solve crashed with dgesv info="
00195       + Teuchos::toString(info),
00196       0, 0.0);
00197   }
00198 }
00199 
00200 
00201 void denseSVD(const LinearOperator<double>& A,
00202   LinearOperator<double>& U,  
00203   Vector<double>& Sigma,
00204   LinearOperator<double>& Vt)
00205 {
00206   VectorSpace<double> mSpace = A.range();
00207   RCP<const SerialVectorSpace> rmSpace 
00208     = rcp_dynamic_cast<const SerialVectorSpace>(mSpace.ptr());
00209 
00210   VectorSpace<double> nSpace = A.domain();
00211   RCP<const SerialVectorSpace> rnSpace 
00212     = rcp_dynamic_cast<const SerialVectorSpace>(nSpace.ptr());
00213 
00214   const DenseSerialMatrix* Aptr 
00215     = dynamic_cast<const DenseSerialMatrix*>(A.ptr().get());
00216   TEST_FOR_EXCEPT(Aptr==0);
00217   /* make a working copy, because dgesvd will overwrite the matrix */
00218   DenseSerialMatrix ATmp = *Aptr;
00219 
00220   int M = ATmp.numRows();
00221   int N = ATmp.numCols();
00222   int S = min(M, N);
00223   
00224   VectorSpace<double> sSpace;
00225   if (S==M) sSpace = mSpace;
00226   else sSpace = nSpace;
00227 
00228   RCP<const SerialVectorSpace> rsSpace 
00229     = rcp_dynamic_cast<const SerialVectorSpace>(sSpace.ptr());
00230 
00231   Sigma = sSpace.createMember();
00232   SerialVector* sigPtr
00233     = dynamic_cast<SerialVector*>(Sigma.ptr().get());
00234   TEST_FOR_EXCEPT(sigPtr==0);
00235 
00236   DenseSerialMatrixFactory umf(rsSpace, rmSpace);
00237   DenseSerialMatrixFactory vtmf(rnSpace, rsSpace);
00238   
00239   U = umf.createMatrix();
00240   Vt = vtmf.createMatrix();
00241 
00242   DenseSerialMatrix* UPtr 
00243     = dynamic_cast<DenseSerialMatrix*>(U.ptr().get());
00244   TEST_FOR_EXCEPT(UPtr==0);
00245 
00246   DenseSerialMatrix* VtPtr 
00247     = dynamic_cast<DenseSerialMatrix*>(Vt.ptr().get());
00248   TEST_FOR_EXCEPT(VtPtr==0);
00249   
00250   double* uData = UPtr->dataPtr();
00251   double* vtData = VtPtr->dataPtr();
00252   double* aData = ATmp.dataPtr();
00253   double* sData = sigPtr->dataPtr();
00254 
00255   char jobu = 'S';
00256   char jobvt = 'S';
00257  
00258   int LDA = M;
00259   int LDU = M;
00260   int LDVT = S;
00261 
00262   int LWORK = max(1, max(3*min(M,N)+max(M,N), 5*min(M,N)));
00263   Array<double> work(LWORK);
00264   
00265   int info = 0;
00266 
00267   dgesvd_(&jobu, &jobvt, &M, &N, aData, &LDA, sData, uData, &LDU, 
00268     vtData, &LDVT, &(work[0]), &LWORK, &info);
00269 
00270   TEST_FOR_EXCEPTION(info != 0, std::runtime_error,
00271     "dgesvd failed with error code info=" << info);
00272 
00273   
00274   
00275 }
00276 
00277 }

Site Contact