Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef TSFBLOCKTRIANGULARSOLVER_IMPL_HPP
00030 #define TSFBLOCKTRIANGULARSOLVER_IMPL_HPP
00031
00032 #include "SundanceDefs.hpp"
00033 #include "TSFLinearSolverDecl.hpp"
00034 #include "TSFLinearCombinationImpl.hpp"
00035 #include "TSFSimpleZeroOpDecl.hpp"
00036 #include "TSFBlockTriangularSolverDecl.hpp"
00037
00038
00039 #ifndef HAVE_TEUCHOS_EXPLICIT_INSTANTIATION
00040 #include "TSFLinearSolverImpl.hpp"
00041 #include "TSFSimpleZeroOpDecl.hpp"
00042 #endif
00043
00044 namespace TSFExtended
00045 {
00046 using namespace TSFExtendedOps;
00047
00048 template <class Scalar> inline
00049 BlockTriangularSolver<Scalar>
00050 ::BlockTriangularSolver(const LinearSolver<Scalar>& solver)
00051 : LinearSolverBase<Scalar>(ParameterList()), solvers_(tuple(solver)) {;}
00052
00053 template <class Scalar> inline
00054 BlockTriangularSolver<Scalar>
00055 ::BlockTriangularSolver(const Array<LinearSolver<Scalar> >& solvers)
00056 : LinearSolverBase<Scalar>(ParameterList()), solvers_(solvers) {;}
00057
00058 template <class Scalar> inline
00059 SolverState<Scalar> BlockTriangularSolver<Scalar>
00060 ::solve(const LinearOperator<Scalar>& op,
00061 const Vector<Scalar>& rhs,
00062 Vector<Scalar>& soln) const
00063 {
00064 int nRows = op.numBlockRows();
00065 int nCols = op.numBlockCols();
00066
00067 soln = op.domain().createMember();
00068
00069
00070 TEST_FOR_EXCEPTION(nRows != rhs.space().numBlocks(), std::runtime_error,
00071 "number of rows in operator " << op
00072 << " not equal to number of blocks on RHS "
00073 << rhs);
00074
00075 TEST_FOR_EXCEPTION(nRows != nCols, std::runtime_error,
00076 "nonsquare block structure in block triangular "
00077 "solver: nRows=" << nRows << " nCols=" << nCols);
00078
00079 bool isUpper = false;
00080 bool isLower = false;
00081
00082 for (int r=0; r<nRows; r++)
00083 {
00084 for (int c=0; c<nCols; c++)
00085 {
00086 if (op.getBlock(r,c).ptr().get() == 0 ||
00087 dynamic_cast<const SimpleZeroOp<Scalar>* >(op.getBlock(r,c).ptr().get()))
00088 {
00089 TEST_FOR_EXCEPTION(r==c, std::runtime_error,
00090 "zero diagonal block (" << r << ", " << c
00091 << " detected in block "
00092 "triangular solver. Operator is " << op);
00093 continue;
00094 }
00095 else
00096 {
00097 if (r < c) isUpper = true;
00098 if (c < r) isLower = true;
00099 }
00100 }
00101 }
00102
00103 TEST_FOR_EXCEPTION(isUpper && isLower, std::runtime_error,
00104 "block triangular solver detected non-triangular operator "
00105 << op);
00106
00107 bool oneSolverFitsAll = false;
00108 if ((int) solvers_.size() == 1 && nRows != 1)
00109 {
00110 oneSolverFitsAll = true;
00111 }
00112
00113 for (int i=0; i<nRows; i++)
00114 {
00115 int r = i;
00116 if (isUpper) r = nRows - 1 - i;
00117 Vector<Scalar> rhs_r = rhs.getBlock(r);
00118 for (int j=0; j<i; j++)
00119 {
00120 int c = j;
00121 if (isUpper) c = nCols - 1 - j;
00122 if (op.getBlock(r,c).ptr().get() != 0)
00123 {
00124 rhs_r = rhs_r - op.getBlock(r,c) * soln.getBlock(c);
00125 }
00126 }
00127
00128 SolverState<Scalar> state;
00129 Vector<Scalar> soln_r;
00130 if (oneSolverFitsAll)
00131 {
00132 state = solvers_[0].solve(op.getBlock(r,r), rhs_r, soln_r);
00133 }
00134 else
00135 {
00136 state = solvers_[r].solve(op.getBlock(r,r), rhs_r, soln_r);
00137 }
00138 if (nRows > 1) soln.setBlock(r, soln_r);
00139 else soln = soln_r;
00140 if (state.finalState() != SolveConverged)
00141 {
00142 return state;
00143 }
00144 }
00145
00146 return SolverState<Scalar>(SolveConverged, "block solves converged",
00147 0, ScalarTraits<Scalar>::zero());
00148 }
00149
00150 }
00151
00152 #endif