TSFGMRESSolver.hpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 /* ***********************************************************************
00003 // 
00004 //           TSFExtended: Trilinos Solver Framework Extended
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // **********************************************************************/
00027 /* @HEADER@ */
00028 
00029 #ifndef TSFGMRESSOLVER_HPP
00030 #define TSFGMRESSOLVER_HPP
00031 
00032 #include "SundanceDefs.hpp"
00033 #include "TSFKrylovSolver.hpp"
00034 #include "SundanceHandleable.hpp"
00035 #include "SundancePrintable.hpp"
00036 #include "Teuchos_Describable.hpp"
00037 #include "TSFLinearCombinationDecl.hpp"
00038 
00039 namespace TSFExtended
00040 {
00041   using namespace Teuchos;
00042 
00043   /**
00044    *
00045    */
00046   template <class Scalar>
00047   class GMRESSolver : public KrylovSolver<Scalar>,
00048                       public Sundance::Handleable<LinearSolverBase<Scalar> >,
00049                       public Printable,
00050                       public Describable
00051   {
00052   public:
00053     /** */
00054     GMRESSolver(const ParameterList& params = ParameterList())
00055       : KrylovSolver<Scalar>(params) {;}
00056     /** */
00057     GMRESSolver(const ParameterList& params,
00058                 const PreconditionerFactory<Scalar>& precond)
00059       : KrylovSolver<Scalar>(params, precond) {;}
00060 
00061     /** */
00062     virtual ~GMRESSolver(){;}
00063 
00064     /** \name Printable interface */
00065     //@{
00066     /** Write to a stream  */
00067     void print(std::ostream& os) const 
00068     {
00069       os << description() << "[" << std::endl;
00070       os << this->parameters() << std::endl;
00071       os << "]" << std::endl;
00072     }
00073     //@}
00074 
00075     /** */
00076     int getKSpace() const {return this->parameters().template get<int>("Restart");}
00077     
00078     /** \name Describable interface */
00079     //@{
00080     /** Write a brief description */
00081     std::string description() const {return "GMRESSolver";}
00082     //@}
00083 
00084     /** \name Handleable interface */
00085     //@{
00086     /** Return a ref count pointer to a newly created object */
00087     virtual RCP<LinearSolverBase<Scalar> > getRcp() 
00088     {return rcp(this);}
00089     //@}
00090     
00091   protected:
00092 
00093     /** */
00094     virtual SolverState<Scalar> solveUnprec(const LinearOperator<Scalar>& A,
00095                                             const Vector<Scalar>& rhs,
00096                                             Vector<Scalar>& soln) const ;
00097 
00098     
00099   };
00100 
00101   template <class Scalar> inline
00102   SolverState<Scalar> GMRESSolver<Scalar>
00103   ::solveUnprec(const LinearOperator<Scalar>& A,
00104                 const Vector<Scalar>& b,
00105                 Vector<Scalar>& soln) const
00106   {
00107     int myRank = MPIComm::world().getRank();
00108 
00109     int maxiters = this->getMaxiters();
00110     int kSpace = getKSpace();
00111     Scalar tol = this->getTol();
00112     int verbosity = this->verb();
00113 
00114     if (verbosity > 1)
00115       {
00116         std::cerr << "GMRES solver" << std::endl;
00117         std::cerr << "Max iterations " << maxiters << std::endl;
00118         std::cerr << "Krylov subspace size " << kSpace<< std::endl;
00119         std::cerr << "Convergence tolerance " << tol << std::endl;
00120       }
00121 
00122     // following GMRES from Matlab
00123     Scalar normOfB = sqrt(b.dot(b));
00124 
00125     /* check for trivial case of zero rhs */
00126     if (normOfB < tol) 
00127       {
00128         soln = b.space().createMember();
00129         soln.zero();
00130         return SolverState<Scalar>(SolveConverged, 
00131                                    "yippee!!", 0, 0.0); 
00132       }
00133 
00134     soln = b.copy();
00135 
00136 
00137     Vector<Scalar> x0 = b.copy();
00138     Vector<Scalar> r0 = b.space().createMember(); 
00139     Vector<Scalar> vh = b.space().createMember(); 
00140     Vector<Scalar> tmp = b.space().createMember(); 
00141     Vector<Scalar> residVec = b.space().createMember(); 
00142     Vector<Scalar> u = b.space().createMember(); 
00143     Vector<Scalar> vrf = b.space().createMember(); 
00144 
00145     std::vector<Scalar> h(kSpace+1);
00146     std::vector<Scalar> f(kSpace+1);
00147     std::vector<Scalar> q(kSpace+1);
00148     std::vector<Scalar> mtmp(kSpace+1);
00149     std::vector<Scalar> y(kSpace+1);
00150 
00151     std::vector<Vector<Scalar> > V(kSpace + 1);
00152     std::vector<Vector<Scalar> > W(kSpace + 1);
00153     std::vector<std::vector<Scalar> > QT(kSpace + 1);
00154     std::vector<std::vector<Scalar> > R(kSpace + 1);
00155 
00156 
00157 
00158     for (int k = 0; k < V.size(); k++) 
00159       {   
00160         V[k] = A.domain().createMember(); // V = n x (m+1)
00161         W[k] = A.domain().createMember(); // W = n x (m+1)
00162         QT[k].resize(kSpace+1);
00163         R[k].resize(kSpace+1); // R = (m+1)x(m+1)
00164       }
00165 
00166 
00167     Scalar relTol = tol * normOfB; // relative tolerance
00168     Scalar phibar;
00169     Scalar rt;
00170     Scalar c;
00171     Scalar s;
00172     Scalar temp;
00173 
00174     int CONV = 0; // not converged yet
00175 
00176     // r0 =  b - A*x0;
00177     r0 = b - A*x0;
00178     residVec = r0.copy();
00179     Scalar normOfResidVec;
00180     normOfResidVec = residVec.norm2();
00181   
00182     // Outer loop i = 1 : maxIters unless convergence (or failure)
00183     //  for (int i=0; i<maxIters_; i++)
00184     int iter = 0;
00185     //    int i = 0;
00186     int j = 0;
00187     while (iter < maxiters)
00188       {
00189         for (int z=0; z<h.size(); z++) 
00190           {
00191             h[z]=0.0;
00192             f[z]=0.0;
00193           }
00194         for (int z = 0; z < V.size(); z++) 
00195           {
00196             V[z].zero();
00197             W[z].zero();
00198             for (int zz=0; zz<QT[z].size(); zz++) 
00199               {QT[z][zz]=0.0; R[z][zz]=0.0;}
00200           }
00201 
00202         vh = residVec.copy();
00203         h[0] = vh.norm2();
00204         double newtemp = 1.0 / h[0];
00205         V[0] = newtemp*vh;
00206         QT[0][0] = 1.0;
00207         phibar = h[0];
00208 
00209         // inner loop from 0:restart
00210         // for(int j=0; j<kSpace; j++)
00211         j = 0;
00212         while ((j < kSpace) & (iter < maxiters))
00213           {
00214             u = A*V[j];
00215             for( int k=0; k<=j; k++)
00216               {
00217                 h[k] = V[k].dot(u);
00218                 u = u - h[k] * V[k];
00219               }
00220           
00221             h[j+1] = u.norm2();
00222             V[j+1] = (1.0 / h[j+1]) * u;
00223 
00224             for (int k=0; k<=j; k++)
00225               {
00226                 for (int z=0; z<=j; z++) 
00227                   R[k][j] = R[k][j] + QT[k][z] * h[z];
00228               }
00229 
00230             rt = R[j][j];
00231 
00232             // find cos(theta) and sin(theta) of Givens rotation
00233             if (h[j+1] == 0)
00234               {
00235                 c = 1.0; // theta = 0 
00236                 s = 0.0;
00237               }
00238             else if (fabs(h[j+1]) > fabs(rt))
00239               {
00240                 temp = rt / h[j+1];
00241                 // pi/4 < theta < 3pi/4
00242                 s = 1.0 / sqrt(1.0 + fabs(temp)*fabs(temp)); 
00243                 c = - temp * s;
00244               }
00245             else
00246               {
00247                 temp = h[j+1] / rt;
00248                 // -pi/4 <= theta < 0 < theta <= pi/4
00249                 c = 1.0 / sqrt(1.0 + fabs(temp)*fabs(temp)); 
00250                 s = - temp * c;
00251               }
00252 
00253             R[j][j] = c * rt - s * h[j+1]; // left out conj on c and s
00254 
00255 
00256             for(int k=0; k<=j; k++)
00257               q[k] = QT[j][k];
00258             for(int k=0; k<=j; k++)
00259               {
00260                 QT[j][k] = c * q[k];
00261                 QT[j+1][k] = s * q[k];
00262               }
00263             QT[j][j+1] = -s;
00264             QT[j+1][j+1] =  c;
00265             f[j] = c * phibar;
00266             phibar = s * phibar;
00267 
00268             if (j < kSpace-1)
00269               {
00270                 W[j] = V[j].copy();
00271                 for(int k=0; k<=j-1; k++)
00272                   W[j] = W[j] - R[k][j]*W[k];
00273                 W[j] = (1.0 / R[j][j]) * W[j];
00274 
00275                 soln = soln + f[j]*W[j];
00276               }
00277             else
00278               {
00279                 for (int zz=0; zz<mtmp.size(); zz++) mtmp[zz]=0.0;
00280                 // back solve to get tmp vector to form vrf
00281                 mtmp[j] = f[j] / R[j][j];
00282                 for(int k=j-1; k>=0; k--)
00283                   {
00284                     mtmp[k] = f[k];
00285                     for(int z=k+1; z<=j; z++)
00286                       mtmp[k] = mtmp[k] - R[k][z]*mtmp[z];
00287                     mtmp[k] = mtmp[k] / R[k][k];
00288                   }
00289               
00290                 vrf.zero();
00291                 for(int k=0; k<=j; k++)
00292                   vrf = vrf + mtmp[k] * V[k];
00293                 soln = x0 + vrf;
00294               }
00295           
00296             // update current resid norm
00297             tmp.zero();
00298             tmp = A*soln;
00299             residVec = b - tmp;
00300             normOfResidVec = residVec.norm2();
00301             
00302 
00303             if (myRank==0 && verbosity > 1 ) 
00304               {
00305                 std::cerr << "GMRES: iteration=";
00306                 std::cerr.width(8);
00307                 std::cerr << iter;
00308                 std::cerr.width(20);
00309                 std::cerr << "scaled resid=" << normOfResidVec/normOfB << std::endl;
00310               }
00311 
00312             // check for convergence
00313             if (normOfResidVec < relTol)
00314               {
00315                 if (j < kSpace-1)
00316                   {
00317                     // compute more accurate soln to test convergence
00318                     for (int zz=0; zz<y.size(); zz++) y[zz]=0.0;
00319 
00320                     // back solve to get y(0:j) = R(0:j,0:j) \ f(0:j);
00321                     y[j] = f[j] / R[j][j];
00322                     for(int k=j-1; k>=0; k--)
00323                       {
00324                         y[k] = f[k];
00325                         for(int z=k+1; z<=j; z++)
00326                           y[k] = y[k] - R[k][z]*y[z];
00327                         y[k] = y[k] / R[k][k];
00328                       }
00329                     // soln = x0 + V(:,0:j) * y(0:j)
00330                     soln = x0.copy();
00331                     for(int k=0; k<=j; k++)
00332                       soln = soln + y[k] * V[k];
00333                   
00334                     tmp.zero();
00335                     tmp = A*soln;
00336                     residVec = b - tmp;
00337                     normOfResidVec = residVec.norm2();
00338                   }
00339                 // test for convergence
00340                 if (normOfResidVec < relTol)
00341                   {
00342                     // we're done
00343                     CONV = 1; // converged
00344                     if (verbosity > 0 && myRank==0)
00345                       {
00346                         std::cerr << "GMRES converged (1) in " << iter+1
00347                              << " iters: final scaled resid = " 
00348                              <<  normOfResidVec/normOfB << std::endl;
00349                       }
00350                     SolverState<Scalar> rtn(SolveConverged, 
00351                                             "yippee!!", iter+1, 
00352                                             normOfResidVec/normOfB);
00353                     return rtn;
00354                   }
00355               }
00356             
00357             j++;
00358             iter++;
00359           } // end inner loop
00360       
00361         if (CONV)
00362           {
00363             if (verbosity > 0 && myRank==0)
00364               {
00365                 std::cerr << "GMRES converged (1) in " << iter+1
00366                      << " iters: final scaled resid = " 
00367                      <<  normOfResidVec/normOfB << std::endl;
00368               }
00369             SolverState<Scalar> rtn(SolveConverged, 
00370                                     "yippee!!", iter+1, 
00371                                     normOfResidVec/normOfB);
00372             return rtn;
00373           }
00374 
00375 
00376         else if (!CONV & (iter < maxiters))
00377           {
00378             // not converged yet; update x0 and resid and restart
00379             x0 = soln.copy();
00380             tmp.zero();
00381             tmp = A*x0;
00382             residVec = b - tmp;
00383             normOfResidVec = residVec.norm2();
00384             
00385             if (verbosity > 1 && myRank==0)
00386               {
00387                 std::cerr << "GMRES restarting: current scaled resid = "
00388                      << normOfResidVec/normOfB << std::endl;
00389               }
00390           }
00391       
00392       } // end outer loop
00393     
00394     SolverState<Scalar> rtn(SolveFailedToConverge, 
00395                             "GMRES failed to converge", 
00396                             maxiters, normOfResidVec/normOfB);
00397     return rtn;
00398   }
00399 }
00400 
00401 #endif

Site Contact