Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef TSFBLOCKTRIANGULARSOLVER_HPP
00030 #define TSFBLOCKTRIANGULARSOLVER_HPP
00031
00032 #include "SundanceDefs.hpp"
00033 #include "TSFLinearSolverDecl.hpp"
00034 #include "TSFLinearCombinationDecl.hpp"
00035 #include "TSFCommonOperatorsDecl.hpp"
00036
00037
00038 namespace TSFExtended
00039 {
00040
00041 template <class Scalar>
00042 class BlockTriangularSolver : public LinearSolverBase<Scalar>,
00043 public Sundance::Handleable<LinearSolverBase<Scalar> >
00044 {
00045 public:
00046
00047 BlockTriangularSolver(const LinearSolver<Scalar>& solver)
00048 : LinearSolverBase<Scalar>(ParameterList()), solvers_(tuple(solver)) {;}
00049
00050
00051 BlockTriangularSolver(const Array<LinearSolver<Scalar> >& solvers)
00052 : LinearSolverBase<Scalar>(ParameterList()), solvers_(solvers) {;}
00053
00054
00055 virtual ~BlockTriangularSolver(){;}
00056
00057
00058 virtual SolverState<Scalar> solve(const LinearOperator<Scalar>& op,
00059 const Vector<Scalar>& rhs,
00060 Vector<Scalar>& soln) const ;
00061
00062
00063 GET_RCP(LinearSolverBase<Scalar>);
00064 private:
00065 Array<LinearSolver<Scalar> > solvers_;
00066 };
00067
00068
00069 template <class Scalar> inline
00070 SolverState<Scalar> BlockTriangularSolver<Scalar>
00071 ::solve(const LinearOperator<Scalar>& op,
00072 const Vector<Scalar>& rhs,
00073 Vector<Scalar>& soln) const
00074 {
00075 int nRows = op.numBlockRows();
00076 int nCols = op.numBlockCols();
00077
00078 soln = op.domain().createMember();
00079
00080
00081 TEST_FOR_EXCEPTION(nRows != rhs.space().numBlocks(), std::runtime_error,
00082 "number of rows in operator " << op
00083 << " not equal to number of blocks on RHS "
00084 << rhs);
00085
00086 TEST_FOR_EXCEPTION(nRows != nCols, std::runtime_error,
00087 "nonsquare block structure in block triangular "
00088 "solver: nRows=" << nRows << " nCols=" << nCols);
00089
00090 bool isUpper = false;
00091 bool isLower = false;
00092
00093 for (int r=0; r<nRows; r++)
00094 {
00095 for (int c=0; c<nCols; c++)
00096 {
00097 if (op.getBlock(r,c).ptr().get() == 0 ||
00098 dynamic_cast<const SimpleZeroOp<Scalar>* >(op.getBlock(r,c).ptr().get()))
00099 {
00100 TEST_FOR_EXCEPTION(r==c, std::runtime_error,
00101 "zero diagonal block (" << r << ", " << c
00102 << " detected in block "
00103 "triangular solver. Operator is " << op);
00104 continue;
00105 }
00106 else
00107 {
00108 if (r < c) isUpper = true;
00109 if (c < r) isLower = true;
00110 }
00111 }
00112 }
00113
00114 TEST_FOR_EXCEPTION(isUpper && isLower, std::runtime_error,
00115 "block triangular solver detected non-triangular operator "
00116 << op);
00117
00118 bool oneSolverFitsAll = false;
00119 if ((int) solvers_.size() == 1 && nRows != 1)
00120 {
00121 oneSolverFitsAll = true;
00122 }
00123
00124 for (int i=0; i<nRows; i++)
00125 {
00126 int r = i;
00127 if (isUpper) r = nRows - 1 - i;
00128 Vector<Scalar> rhs_r = rhs.getBlock(r);
00129 for (int j=0; j<i; j++)
00130 {
00131 int c = j;
00132 if (isUpper) c = nCols - 1 - j;
00133 if (op.getBlock(r,c).ptr().get() != 0)
00134 {
00135 rhs_r = rhs_r - op.getBlock(r,c) * soln.getBlock(c);
00136 }
00137 }
00138
00139 SolverState<Scalar> state;
00140 Vector<Scalar> soln_r;
00141 if (oneSolverFitsAll)
00142 {
00143 state = solvers_[0].solve(op.getBlock(r,r), rhs_r, soln_r);
00144 }
00145 else
00146 {
00147 state = solvers_[r].solve(op.getBlock(r,r), rhs_r, soln_r);
00148 }
00149 if (nRows > 1) soln.setBlock(r, soln_r);
00150 else soln = soln_r;
00151 if (state.finalState() != SolveConverged)
00152 {
00153 return state;
00154 }
00155 }
00156
00157 return SolverState<Scalar>(SolveConverged, "block solves converged",
00158 0, ScalarTraits<Scalar>::zero());
00159 }
00160
00161 }
00162
00163 #endif