TSFBlockTriangularSolver.hpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 /* ***********************************************************************
00003 // 
00004 //           TSFExtended: Trilinos Solver Framework Extended
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // **********************************************************************/
00027  /* @HEADER@ */
00028 
00029 #ifndef TSFBLOCKTRIANGULARSOLVER_HPP
00030 #define TSFBLOCKTRIANGULARSOLVER_HPP
00031 
00032 #include "SundanceDefs.hpp"
00033 #include "TSFLinearSolverDecl.hpp" 
00034 #include "TSFLinearCombinationDecl.hpp" 
00035 #include "TSFCommonOperatorsDecl.hpp" 
00036 
00037 
00038 namespace TSFExtended
00039 {
00040   /** */
00041   template <class Scalar>
00042   class BlockTriangularSolver : public LinearSolverBase<Scalar>,
00043                                 public Sundance::Handleable<LinearSolverBase<Scalar> >
00044   {
00045   public:
00046     /** */
00047     BlockTriangularSolver(const LinearSolver<Scalar>& solver)
00048       : LinearSolverBase<Scalar>(ParameterList()), solvers_(tuple(solver)) {;}
00049 
00050     /** */
00051     BlockTriangularSolver(const Array<LinearSolver<Scalar> >& solvers)
00052       : LinearSolverBase<Scalar>(ParameterList()), solvers_(solvers) {;}
00053 
00054     /** */
00055     virtual ~BlockTriangularSolver(){;}
00056 
00057     /** */
00058     virtual SolverState<Scalar> solve(const LinearOperator<Scalar>& op,
00059                                       const Vector<Scalar>& rhs,
00060                                       Vector<Scalar>& soln) const ;
00061 
00062     /* */
00063     GET_RCP(LinearSolverBase<Scalar>);
00064   private:
00065     Array<LinearSolver<Scalar> > solvers_;
00066   };
00067 
00068 
00069   template <class Scalar> inline
00070   SolverState<Scalar> BlockTriangularSolver<Scalar>
00071   ::solve(const LinearOperator<Scalar>& op,
00072           const Vector<Scalar>& rhs,
00073           Vector<Scalar>& soln) const
00074   {
00075     int nRows = op.numBlockRows();
00076     int nCols = op.numBlockCols();
00077 
00078     soln = op.domain().createMember();
00079     //    bool converged = false;
00080 
00081     TEST_FOR_EXCEPTION(nRows != rhs.space().numBlocks(), std::runtime_error,
00082                        "number of rows in operator " << op
00083                        << " not equal to number of blocks on RHS "
00084                        << rhs);
00085 
00086     TEST_FOR_EXCEPTION(nRows != nCols, std::runtime_error,
00087                        "nonsquare block structure in block triangular "
00088                        "solver: nRows=" << nRows << " nCols=" << nCols);
00089 
00090     bool isUpper = false;
00091     bool isLower = false;
00092 
00093     for (int r=0; r<nRows; r++)
00094       {
00095         for (int c=0; c<nCols; c++)
00096           {
00097             if (op.getBlock(r,c).ptr().get() == 0 ||
00098                 dynamic_cast<const SimpleZeroOp<Scalar>* >(op.getBlock(r,c).ptr().get()))
00099               {
00100                 TEST_FOR_EXCEPTION(r==c, std::runtime_error,
00101                                    "zero diagonal block (" << r << ", " << c 
00102                                    << " detected in block "
00103                                    "triangular solver. Operator is " << op);
00104                 continue;
00105               }
00106             else
00107               {
00108                 if (r < c) isUpper = true;
00109                 if (c < r) isLower = true;
00110               }
00111           }
00112       }
00113 
00114     TEST_FOR_EXCEPTION(isUpper && isLower, std::runtime_error, 
00115                        "block triangular solver detected non-triangular operator "
00116                        << op);
00117 
00118     bool oneSolverFitsAll = false;
00119     if ((int) solvers_.size() == 1 && nRows != 1) 
00120       {
00121         oneSolverFitsAll = true;
00122       }
00123 
00124     for (int i=0; i<nRows; i++)
00125       {
00126         int r = i;
00127         if (isUpper) r = nRows - 1 - i;
00128         Vector<Scalar> rhs_r = rhs.getBlock(r);
00129         for (int j=0; j<i; j++)
00130           {
00131             int c = j;
00132             if (isUpper) c = nCols - 1 - j;
00133             if (op.getBlock(r,c).ptr().get() != 0)
00134               {
00135                 rhs_r = rhs_r - op.getBlock(r,c) * soln.getBlock(c);
00136               }
00137           }
00138 
00139         SolverState<Scalar> state;
00140         Vector<Scalar> soln_r;
00141         if (oneSolverFitsAll)
00142           {
00143             state = solvers_[0].solve(op.getBlock(r,r), rhs_r, soln_r);
00144           }
00145         else
00146           {
00147             state = solvers_[r].solve(op.getBlock(r,r), rhs_r, soln_r);
00148           }
00149         if (nRows > 1) soln.setBlock(r, soln_r);
00150         else soln = soln_r;
00151         if (state.finalState() != SolveConverged)
00152           {
00153             return state;
00154           }
00155       }
00156 
00157     return SolverState<Scalar>(SolveConverged, "block solves converged",
00158                                0, ScalarTraits<Scalar>::zero());
00159   }
00160   
00161 }
00162 
00163 #endif

Site Contact