Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef TSFGMRESSOLVER_HPP
00030 #define TSFGMRESSOLVER_HPP
00031
00032 #include "SundanceDefs.hpp"
00033 #include "TSFKrylovSolver.hpp"
00034 #include "SundanceHandleable.hpp"
00035 #include "SundancePrintable.hpp"
00036 #include "Teuchos_Describable.hpp"
00037 #include "TSFLinearCombinationDecl.hpp"
00038
00039 namespace TSFExtended
00040 {
00041 using namespace Teuchos;
00042
00043
00044
00045
00046 template <class Scalar>
00047 class GMRESSolver : public KrylovSolver<Scalar>,
00048 public Sundance::Handleable<LinearSolverBase<Scalar> >,
00049 public Printable,
00050 public Describable
00051 {
00052 public:
00053
00054 GMRESSolver(const ParameterList& params = ParameterList())
00055 : KrylovSolver<Scalar>(params) {;}
00056
00057 GMRESSolver(const ParameterList& params,
00058 const PreconditionerFactory<Scalar>& precond)
00059 : KrylovSolver<Scalar>(params, precond) {;}
00060
00061
00062 virtual ~GMRESSolver(){;}
00063
00064
00065
00066
00067 void print(std::ostream& os) const
00068 {
00069 os << description() << "[" << std::endl;
00070 os << this->parameters() << std::endl;
00071 os << "]" << std::endl;
00072 }
00073
00074
00075
00076 int getKSpace() const {return this->parameters().template get<int>("Restart");}
00077
00078
00079
00080
00081 std::string description() const {return "GMRESSolver";}
00082
00083
00084
00085
00086
00087 virtual RCP<LinearSolverBase<Scalar> > getRcp()
00088 {return rcp(this);}
00089
00090
00091 protected:
00092
00093
00094 virtual SolverState<Scalar> solveUnprec(const LinearOperator<Scalar>& A,
00095 const Vector<Scalar>& rhs,
00096 Vector<Scalar>& soln) const ;
00097
00098
00099 };
00100
00101 template <class Scalar> inline
00102 SolverState<Scalar> GMRESSolver<Scalar>
00103 ::solveUnprec(const LinearOperator<Scalar>& A,
00104 const Vector<Scalar>& b,
00105 Vector<Scalar>& soln) const
00106 {
00107 int myRank = MPIComm::world().getRank();
00108
00109 int maxiters = this->getMaxiters();
00110 int kSpace = getKSpace();
00111 Scalar tol = this->getTol();
00112 int verbosity = this->verb();
00113
00114 if (verbosity > 1)
00115 {
00116 std::cerr << "GMRES solver" << std::endl;
00117 std::cerr << "Max iterations " << maxiters << std::endl;
00118 std::cerr << "Krylov subspace size " << kSpace<< std::endl;
00119 std::cerr << "Convergence tolerance " << tol << std::endl;
00120 }
00121
00122
00123 Scalar normOfB = sqrt(b.dot(b));
00124
00125
00126 if (normOfB < tol)
00127 {
00128 soln = b.space().createMember();
00129 soln.zero();
00130 return SolverState<Scalar>(SolveConverged,
00131 "yippee!!", 0, 0.0);
00132 }
00133
00134 soln = b.copy();
00135
00136
00137 Vector<Scalar> x0 = b.copy();
00138 Vector<Scalar> r0 = b.space().createMember();
00139 Vector<Scalar> vh = b.space().createMember();
00140 Vector<Scalar> tmp = b.space().createMember();
00141 Vector<Scalar> residVec = b.space().createMember();
00142 Vector<Scalar> u = b.space().createMember();
00143 Vector<Scalar> vrf = b.space().createMember();
00144
00145 std::vector<Scalar> h(kSpace+1);
00146 std::vector<Scalar> f(kSpace+1);
00147 std::vector<Scalar> q(kSpace+1);
00148 std::vector<Scalar> mtmp(kSpace+1);
00149 std::vector<Scalar> y(kSpace+1);
00150
00151 std::vector<Vector<Scalar> > V(kSpace + 1);
00152 std::vector<Vector<Scalar> > W(kSpace + 1);
00153 std::vector<std::vector<Scalar> > QT(kSpace + 1);
00154 std::vector<std::vector<Scalar> > R(kSpace + 1);
00155
00156
00157
00158 for (int k = 0; k < V.size(); k++)
00159 {
00160 V[k] = A.domain().createMember();
00161 W[k] = A.domain().createMember();
00162 QT[k].resize(kSpace+1);
00163 R[k].resize(kSpace+1);
00164 }
00165
00166
00167 Scalar relTol = tol * normOfB;
00168 Scalar phibar;
00169 Scalar rt;
00170 Scalar c;
00171 Scalar s;
00172 Scalar temp;
00173
00174 int CONV = 0;
00175
00176
00177 r0 = b - A*x0;
00178 residVec = r0.copy();
00179 Scalar normOfResidVec;
00180 normOfResidVec = residVec.norm2();
00181
00182
00183
00184 int iter = 0;
00185
00186 int j = 0;
00187 while (iter < maxiters)
00188 {
00189 for (int z=0; z<h.size(); z++)
00190 {
00191 h[z]=0.0;
00192 f[z]=0.0;
00193 }
00194 for (int z = 0; z < V.size(); z++)
00195 {
00196 V[z].zero();
00197 W[z].zero();
00198 for (int zz=0; zz<QT[z].size(); zz++)
00199 {QT[z][zz]=0.0; R[z][zz]=0.0;}
00200 }
00201
00202 vh = residVec.copy();
00203 h[0] = vh.norm2();
00204 double newtemp = 1.0 / h[0];
00205 V[0] = newtemp*vh;
00206 QT[0][0] = 1.0;
00207 phibar = h[0];
00208
00209
00210
00211 j = 0;
00212 while ((j < kSpace) & (iter < maxiters))
00213 {
00214 u = A*V[j];
00215 for( int k=0; k<=j; k++)
00216 {
00217 h[k] = V[k].dot(u);
00218 u = u - h[k] * V[k];
00219 }
00220
00221 h[j+1] = u.norm2();
00222 V[j+1] = (1.0 / h[j+1]) * u;
00223
00224 for (int k=0; k<=j; k++)
00225 {
00226 for (int z=0; z<=j; z++)
00227 R[k][j] = R[k][j] + QT[k][z] * h[z];
00228 }
00229
00230 rt = R[j][j];
00231
00232
00233 if (h[j+1] == 0)
00234 {
00235 c = 1.0;
00236 s = 0.0;
00237 }
00238 else if (fabs(h[j+1]) > fabs(rt))
00239 {
00240 temp = rt / h[j+1];
00241
00242 s = 1.0 / sqrt(1.0 + fabs(temp)*fabs(temp));
00243 c = - temp * s;
00244 }
00245 else
00246 {
00247 temp = h[j+1] / rt;
00248
00249 c = 1.0 / sqrt(1.0 + fabs(temp)*fabs(temp));
00250 s = - temp * c;
00251 }
00252
00253 R[j][j] = c * rt - s * h[j+1];
00254
00255
00256 for(int k=0; k<=j; k++)
00257 q[k] = QT[j][k];
00258 for(int k=0; k<=j; k++)
00259 {
00260 QT[j][k] = c * q[k];
00261 QT[j+1][k] = s * q[k];
00262 }
00263 QT[j][j+1] = -s;
00264 QT[j+1][j+1] = c;
00265 f[j] = c * phibar;
00266 phibar = s * phibar;
00267
00268 if (j < kSpace-1)
00269 {
00270 W[j] = V[j].copy();
00271 for(int k=0; k<=j-1; k++)
00272 W[j] = W[j] - R[k][j]*W[k];
00273 W[j] = (1.0 / R[j][j]) * W[j];
00274
00275 soln = soln + f[j]*W[j];
00276 }
00277 else
00278 {
00279 for (int zz=0; zz<mtmp.size(); zz++) mtmp[zz]=0.0;
00280
00281 mtmp[j] = f[j] / R[j][j];
00282 for(int k=j-1; k>=0; k--)
00283 {
00284 mtmp[k] = f[k];
00285 for(int z=k+1; z<=j; z++)
00286 mtmp[k] = mtmp[k] - R[k][z]*mtmp[z];
00287 mtmp[k] = mtmp[k] / R[k][k];
00288 }
00289
00290 vrf.zero();
00291 for(int k=0; k<=j; k++)
00292 vrf = vrf + mtmp[k] * V[k];
00293 soln = x0 + vrf;
00294 }
00295
00296
00297 tmp.zero();
00298 tmp = A*soln;
00299 residVec = b - tmp;
00300 normOfResidVec = residVec.norm2();
00301
00302
00303 if (myRank==0 && verbosity > 1 )
00304 {
00305 std::cerr << "GMRES: iteration=";
00306 std::cerr.width(8);
00307 std::cerr << iter;
00308 std::cerr.width(20);
00309 std::cerr << "scaled resid=" << normOfResidVec/normOfB << std::endl;
00310 }
00311
00312
00313 if (normOfResidVec < relTol)
00314 {
00315 if (j < kSpace-1)
00316 {
00317
00318 for (int zz=0; zz<y.size(); zz++) y[zz]=0.0;
00319
00320
00321 y[j] = f[j] / R[j][j];
00322 for(int k=j-1; k>=0; k--)
00323 {
00324 y[k] = f[k];
00325 for(int z=k+1; z<=j; z++)
00326 y[k] = y[k] - R[k][z]*y[z];
00327 y[k] = y[k] / R[k][k];
00328 }
00329
00330 soln = x0.copy();
00331 for(int k=0; k<=j; k++)
00332 soln = soln + y[k] * V[k];
00333
00334 tmp.zero();
00335 tmp = A*soln;
00336 residVec = b - tmp;
00337 normOfResidVec = residVec.norm2();
00338 }
00339
00340 if (normOfResidVec < relTol)
00341 {
00342
00343 CONV = 1;
00344 if (verbosity > 0 && myRank==0)
00345 {
00346 std::cerr << "GMRES converged (1) in " << iter+1
00347 << " iters: final scaled resid = "
00348 << normOfResidVec/normOfB << std::endl;
00349 }
00350 SolverState<Scalar> rtn(SolveConverged,
00351 "yippee!!", iter+1,
00352 normOfResidVec/normOfB);
00353 return rtn;
00354 }
00355 }
00356
00357 j++;
00358 iter++;
00359 }
00360
00361 if (CONV)
00362 {
00363 if (verbosity > 0 && myRank==0)
00364 {
00365 std::cerr << "GMRES converged (1) in " << iter+1
00366 << " iters: final scaled resid = "
00367 << normOfResidVec/normOfB << std::endl;
00368 }
00369 SolverState<Scalar> rtn(SolveConverged,
00370 "yippee!!", iter+1,
00371 normOfResidVec/normOfB);
00372 return rtn;
00373 }
00374
00375
00376 else if (!CONV & (iter < maxiters))
00377 {
00378
00379 x0 = soln.copy();
00380 tmp.zero();
00381 tmp = A*x0;
00382 residVec = b - tmp;
00383 normOfResidVec = residVec.norm2();
00384
00385 if (verbosity > 1 && myRank==0)
00386 {
00387 std::cerr << "GMRES restarting: current scaled resid = "
00388 << normOfResidVec/normOfB << std::endl;
00389 }
00390 }
00391
00392 }
00393
00394 SolverState<Scalar> rtn(SolveFailedToConverge,
00395 "GMRES failed to converge",
00396 maxiters, normOfResidVec/normOfB);
00397 return rtn;
00398 }
00399 }
00400
00401 #endif