00001 #include "TSFLAPACKGeneralMatrix.hpp"
00002 #include "Teuchos_LAPACK.hpp"
00003
00004 using namespace TSFExtended;
00005 using namespace Teuchos;
00006
00007 LAPACKGeneralMatrix::LAPACKGeneralMatrix()
00008 : nRows_(0),
00009 nCols_(0),
00010 data_(),
00011 iPiv_(),
00012 isFactored_(false),
00013 factorData_()
00014 {}
00015
00016 LAPACKGeneralMatrix::LAPACKGeneralMatrix(int nRows, int nCols)
00017 : nRows_(nRows),
00018 nCols_(nCols),
00019 data_(nRows*nCols),
00020 iPiv_(),
00021 isFactored_(false),
00022 factorData_()
00023 {}
00024
00025
00026 void LAPACKGeneralMatrix::apply(const DenseSerialVector& in,
00027 DenseSerialVector& out) const
00028 {
00029 mvMult(false, in, out);
00030 }
00031
00032 void LAPACKGeneralMatrix::applyAdjoint(const DenseSerialVector& in,
00033 DenseSerialVector& out) const
00034 {
00035 mvMult(true, in, out);
00036 }
00037
00038 void LAPACKGeneralMatrix::applyInverse(const DenseSerialVector& in,
00039 DenseSerialVector& out) const
00040 {
00041 solve(false, in, out);
00042 }
00043
00044 void LAPACKGeneralMatrix::applyInverseAdjoint(const DenseSerialVector& in,
00045 DenseSerialVector& out) const
00046 {
00047 solve(true, in, out);
00048 }
00049
00050
00051
00052
00053 void LAPACKGeneralMatrix::mvMult(bool transpose, const DenseSerialVector& in,
00054 DenseSerialVector& out) const
00055 {
00056 const double* inPtr = &(in[0]);
00057 double* outPtr = &(out[0]);
00058
00059
00060 ETransp transFlag=NO_TRANS;
00061 if (transpose) transFlag=TRANS;
00062
00063
00064 DenseSerialVector::blasObject().GEMV(transFlag, nRows_, nCols_, 1.0,
00065 &(data_[0]),
00066 nRows_, inPtr, 1, 0.0,
00067 outPtr, 1);
00068 }
00069
00070
00071 void LAPACKGeneralMatrix::solve(bool transpose, const DenseSerialVector& in,
00072 DenseSerialVector& out) const
00073 {
00074 int info = 0;
00075
00076 const double* inPtr = &(in[0]);
00077 double* outPtr = &(out[0]);
00078
00079
00080
00081
00082 DenseSerialVector::blasObject().COPY(nRows_, inPtr, 1, outPtr, 1);
00083
00084
00085 if (!isFactored_)
00086 {
00087 factor();
00088 isFactored_ = true;
00089 }
00090 double* dataPtr = const_cast<double*>(&(factorData_[0]));
00091
00092
00093 char transFlag='N';
00094 if (transpose) transFlag='T';;
00095
00096
00097 int* pivPtr = const_cast<int*>(&(iPiv_[0]));
00098 LAPACK<int, double> lapackObj;
00099 lapackObj.GETRS(transFlag, nRows_, 1, dataPtr,
00100 nRows_, pivPtr, outPtr, nRows_, &info);
00101
00102 TEST_FOR_EXCEPTION(info != 0,
00103 std::runtime_error,
00104 "LAPACKGeneralMatrix backsolve failed with error code"
00105 << info);
00106 }
00107
00108
00109
00110
00111
00112
00113
00114
00115 void LAPACKGeneralMatrix::setElement(int i, int j, const double& aij)
00116 {
00117 isFactored_ = false;
00118 data_[nRows_*j + i] = aij;
00119 }
00120
00121 void LAPACKGeneralMatrix::zero()
00122 {
00123 isFactored_ = false;
00124 data_.zero();
00125 }
00126
00127 void LAPACKGeneralMatrix::factor() const
00128 {
00129 int info = 0;
00130
00131 iPiv_.resize(nRows_);
00132
00133 int* pivPtr = const_cast<int*>(&(iPiv_[0]));
00134
00135 factorData_ = data_;
00136
00137 double* dataPtr = const_cast<double*>(&(factorData_[0]));
00138
00139 LAPACK<int, double> lapackObj;
00140
00141 lapackObj.GETRF(nRows_, nCols_, dataPtr, nRows_, pivPtr, &info);
00142
00143
00144 TEST_FOR_EXCEPTION(info != 0,
00145 std::runtime_error,
00146 "LAPACKGeneralMatrix factorization failed with error code"
00147 << info);
00148
00149 isFactored_ = true;
00150 }
00151
00152 void LAPACKGeneralMatrix::print(std::ostream& os) const
00153 {
00154 os << "LAPACK " << nRows_ << "-by-" << nCols_ << " matrix: " << std::endl;
00155 os << "[";
00156 for (int i=0; i<nRows_; i++)
00157 {
00158 os << "[";
00159 for (int j=0; j<nCols_; j++)
00160 {
00161 os << data_[i + nRows_*j];
00162 if (j < nCols_-1) os << ", ";
00163 }
00164 os << "]";
00165 if (i < nRows_-1) os << ", ";
00166 }
00167 os << "]";
00168 }