TSFLAPACKGeneralMatrix.cpp
Go to the documentation of this file.
00001 #include "TSFLAPACKGeneralMatrix.hpp"
00002 #include "Teuchos_LAPACK.hpp"
00003 
00004 using namespace TSFExtended;
00005 using namespace Teuchos;
00006 
00007 LAPACKGeneralMatrix::LAPACKGeneralMatrix()
00008   : nRows_(0),
00009     nCols_(0),
00010     data_(),
00011     iPiv_(),
00012     isFactored_(false),
00013     factorData_()
00014 {}
00015 
00016 LAPACKGeneralMatrix::LAPACKGeneralMatrix(int nRows, int nCols)
00017   : nRows_(nRows),
00018     nCols_(nCols),
00019     data_(nRows*nCols),
00020     iPiv_(),
00021     isFactored_(false),
00022     factorData_()
00023 {}
00024 
00025 
00026 void LAPACKGeneralMatrix::apply(const DenseSerialVector& in,
00027                                 DenseSerialVector& out) const
00028 {
00029   mvMult(false, in, out);
00030 }
00031 
00032 void LAPACKGeneralMatrix::applyAdjoint(const DenseSerialVector& in,
00033                                        DenseSerialVector& out) const
00034 {
00035   mvMult(true, in, out);
00036 }
00037 
00038 void LAPACKGeneralMatrix::applyInverse(const DenseSerialVector& in,
00039                                        DenseSerialVector& out) const
00040 {
00041   solve(false, in, out);
00042 }
00043 
00044 void LAPACKGeneralMatrix::applyInverseAdjoint(const DenseSerialVector& in,
00045                                               DenseSerialVector& out) const
00046 {
00047   solve(true, in, out);
00048 }
00049 
00050 
00051 
00052 
00053 void LAPACKGeneralMatrix::mvMult(bool transpose, const DenseSerialVector& in,
00054                                  DenseSerialVector& out) const
00055 {
00056   const double* inPtr = &(in[0]);
00057   double* outPtr = &(out[0]);
00058   
00059   /* set the LAPACK transpose flag = "N" for no transpose, "T" for transpose */
00060   ETransp transFlag=NO_TRANS;
00061   if (transpose) transFlag=TRANS;
00062   
00063   // RAB & ADP : 7/10/2002 : We have fixed this!
00064   DenseSerialVector::blasObject().GEMV(transFlag, nRows_, nCols_, 1.0, 
00065                                         &(data_[0]),
00066                                         nRows_, inPtr, 1, 0.0, 
00067                                         outPtr, 1);
00068 }
00069 
00070 
00071 void LAPACKGeneralMatrix::solve(bool transpose, const DenseSerialVector& in,
00072                                 DenseSerialVector& out) const
00073 {
00074   int info = 0;
00075 
00076   const double* inPtr = &(in[0]);
00077   double* outPtr = &(out[0]);
00078 
00079   /* LAPACK overwrites the input vector argument. We copy the input 
00080    * vector into the output vector, and then pass the output vector
00081    * to the backsolve routine. */
00082   DenseSerialVector::blasObject().COPY(nRows_, inPtr, 1, outPtr, 1);
00083 
00084   /* factor if we haven't already done so */
00085   if (!isFactored_)
00086     {
00087       factor();
00088       isFactored_ = true;
00089     }
00090   double* dataPtr = const_cast<double*>(&(factorData_[0]));
00091 
00092   /* set the LAPACK transpose flag = "N" for no transpose, "T" for transpose */
00093   char transFlag='N';
00094   if (transpose) transFlag='T';;
00095 
00096   /* backsolve */
00097   int* pivPtr = const_cast<int*>(&(iPiv_[0]));
00098   LAPACK<int, double> lapackObj;
00099   lapackObj.GETRS(transFlag, nRows_, 1, dataPtr,
00100                   nRows_, pivPtr, outPtr, nRows_, &info);
00101 
00102   TEST_FOR_EXCEPTION(info != 0,
00103                      std::runtime_error,
00104                      "LAPACKGeneralMatrix backsolve failed with error code"
00105                      << info);
00106 }
00107 
00108 
00109 
00110 
00111 
00112 
00113 
00114 
00115 void LAPACKGeneralMatrix::setElement(int i, int j, const double& aij)
00116 {
00117   isFactored_ = false;
00118   data_[nRows_*j + i] = aij;
00119 }
00120 
00121 void LAPACKGeneralMatrix::zero()
00122 {
00123   isFactored_ = false;
00124   data_.zero();
00125 }
00126 
00127 void LAPACKGeneralMatrix::factor() const 
00128 {
00129   int info = 0;
00130 
00131   iPiv_.resize(nRows_);
00132 
00133   int* pivPtr = const_cast<int*>(&(iPiv_[0]));
00134 
00135   factorData_ = data_;
00136 
00137   double* dataPtr = const_cast<double*>(&(factorData_[0]));
00138 
00139   LAPACK<int, double> lapackObj;
00140 
00141   lapackObj.GETRF(nRows_, nCols_, dataPtr, nRows_, pivPtr, &info);
00142 
00143 
00144   TEST_FOR_EXCEPTION(info != 0,
00145                      std::runtime_error,
00146                      "LAPACKGeneralMatrix factorization failed with error code"
00147                      << info);
00148 
00149   isFactored_ = true;
00150 }
00151 
00152 void LAPACKGeneralMatrix::print(std::ostream& os) const
00153 {
00154   os << "LAPACK " << nRows_ << "-by-" << nCols_ << " matrix: " << std::endl;
00155   os << "[";
00156   for (int i=0; i<nRows_; i++)
00157     {
00158       os << "[";
00159       for (int j=0; j<nCols_; j++)
00160         {
00161           os << data_[i + nRows_*j];
00162           if (j < nCols_-1) os << ", ";
00163         }
00164       os << "]";
00165       if (i < nRows_-1) os << ", ";
00166     }
00167   os << "]";
00168 }

Site Contact