|
Tpetra Matrix/Vector Services Version of the Day
|
00001 //@HEADER 00002 // ************************************************************************ 00003 // 00004 // Tpetra: Templated Linear Algebra Services Package 00005 // Copyright (2008) Sandia Corporation 00006 // 00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive 00008 // license for use of this work by or on behalf of the U.S. Government. 00009 // 00010 // This library is free software; you can redistribute it and/or modify 00011 // it under the terms of the GNU Lesser General Public License as 00012 // published by the Free Software Foundation; either version 2.1 of the 00013 // License, or (at your option) any later version. 00014 // 00015 // This library is distributed in the hope that it will be useful, but 00016 // WITHOUT ANY WARRANTY; without even the implied warranty of 00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00018 // Lesser General Public License for more details. 00019 // 00020 // You should have received a copy of the GNU Lesser General Public 00021 // License along with this library; if not, write to the Free Software 00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 00023 // USA 00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 00025 // 00026 // ************************************************************************ 00027 //@HEADER 00028 00029 #ifndef TPETRA_CRSMATRIX_DECL_HPP 00030 #define TPETRA_CRSMATRIX_DECL_HPP 00031 00032 // TODO: row-wise insertion of entries in globalAssemble() may be more efficient 00033 00034 // TODO: add typeglobs: CrsMatrix<Scalar,typeglob> 00035 // TODO: add template (template) parameter for nonlocal container (this will be part of typeglob) 00036 00037 #include <Kokkos_DefaultNode.hpp> 00038 #include <Kokkos_DefaultKernels.hpp> 00039 #include <Kokkos_CrsMatrix.hpp> 00040 00041 #include "Tpetra_ConfigDefs.hpp" 00042 #include "Tpetra_RowMatrix.hpp" 00043 #include "Tpetra_DistObject.hpp" 00044 #include "Tpetra_CrsGraph.hpp" 00045 #include "Tpetra_Vector.hpp" 00046 #include "Tpetra_CrsMatrixMultiplyOp_decl.hpp" 00047 00053 #ifndef DOXYGEN_SHOULD_SKIP_THIS 00054 namespace Tpetra { 00055 // struct for i,j,v triplets 00056 template <class Ordinal, class Scalar> 00057 struct CrsIJV { 00058 CrsIJV(); 00059 CrsIJV(Ordinal row, Ordinal col, const Scalar &val); 00060 Ordinal i,j; 00061 Scalar v; 00062 }; 00063 } 00064 00065 namespace Teuchos { 00066 // SerializationTraits specialization for CrsIJV, using DirectSerialization 00067 template <typename Ordinal, typename Scalar> 00068 class SerializationTraits<int,Tpetra::CrsIJV<Ordinal,Scalar> > 00069 : public DirectSerializationTraits<int,Tpetra::CrsIJV<Ordinal,Scalar> > 00070 {}; 00071 } 00072 00073 namespace std { 00074 template <class Ordinal, class Scalar> 00075 bool operator<(const Tpetra::CrsIJV<Ordinal,Scalar> &ijv1, const Tpetra::CrsIJV<Ordinal,Scalar> &ijv2); 00076 } 00077 #endif 00078 00079 namespace Tpetra { 00080 00082 00108 template <class Scalar, 00109 class LocalOrdinal = int, 00110 class GlobalOrdinal = LocalOrdinal, 00111 class Node = Kokkos::DefaultNode::DefaultNodeType, 00112 class LocalMatOps = typename Kokkos::DefaultKernels<Scalar,LocalOrdinal,Node>::SparseOps > 00113 class CrsMatrix : public RowMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node>, 00114 public DistObject<char, LocalOrdinal,GlobalOrdinal,Node> { 00115 public: 00116 typedef Scalar scalar_type; 00117 typedef LocalOrdinal local_ordinal_type; 00118 typedef GlobalOrdinal global_ordinal_type; 00119 typedef Node node_type; 00120 // backwards compatibility defines both of these 00121 typedef LocalMatOps mat_vec_type; 00122 typedef LocalMatOps mat_solve_type; 00123 00125 00126 00128 CrsMatrix(const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > &rowMap, size_t maxNumEntriesPerRow, ProfileType pftype = DynamicProfile); 00129 00131 CrsMatrix(const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > &rowMap, const ArrayRCP<const size_t> &NumEntriesPerRowToAlloc, ProfileType pftype = DynamicProfile); 00132 00134 00136 CrsMatrix(const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > &rowMap, const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > &colMap, size_t maxNumEntriesPerRow, ProfileType pftype = DynamicProfile); 00137 00139 00141 CrsMatrix(const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > &rowMap, const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > &colMap, const ArrayRCP<const size_t> &NumEntriesPerRowToAlloc, ProfileType pftype = DynamicProfile); 00142 00144 explicit CrsMatrix(const RCP<const CrsGraph<LocalOrdinal,GlobalOrdinal,Node,LocalMatOps> > &graph); 00145 00146 // !Destructor. 00147 virtual ~CrsMatrix(); 00148 00150 00152 00153 00155 00166 void insertGlobalValues(GlobalOrdinal globalRow, const ArrayView<const GlobalOrdinal> &cols, const ArrayView<const Scalar> &vals); 00167 00169 00180 void insertLocalValues(LocalOrdinal localRow, const ArrayView<const LocalOrdinal> &cols, const ArrayView<const Scalar> &vals); 00181 00183 00188 void replaceGlobalValues(GlobalOrdinal globalRow, 00189 const ArrayView<const GlobalOrdinal> &cols, 00190 const ArrayView<const Scalar> &vals); 00191 00193 00195 void replaceLocalValues(LocalOrdinal localRow, 00196 const ArrayView<const LocalOrdinal> &cols, 00197 const ArrayView<const Scalar> &vals); 00198 00200 00205 void sumIntoGlobalValues(GlobalOrdinal globalRow, 00206 const ArrayView<const GlobalOrdinal> &cols, 00207 const ArrayView<const Scalar> &vals); 00208 00209 00211 00216 void sumIntoLocalValues(LocalOrdinal globalRow, 00217 const ArrayView<const LocalOrdinal> &cols, 00218 const ArrayView<const Scalar> &vals); 00219 00221 void setAllToScalar(const Scalar &alpha); 00222 00224 void scale(const Scalar &alpha); 00225 00227 00229 00230 00232 void globalAssemble(); 00233 00242 void resumeFill(); 00243 00255 void fillComplete(const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > &domainMap, const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > &rangeMap, OptimizeOption os = DoOptimizeStorage); 00256 00270 void fillComplete(OptimizeOption os = DoOptimizeStorage); 00271 00273 00275 00276 00278 const RCP<const Comm<int> > & getComm() const; 00279 00281 RCP<Node> getNode() const; 00282 00284 const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > & getRowMap() const; 00285 00287 const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > & getColMap() const; 00288 00290 RCP<const RowGraph<LocalOrdinal,GlobalOrdinal,Node> > getGraph() const; 00291 00293 RCP<const CrsGraph<LocalOrdinal,GlobalOrdinal,Node,LocalMatOps> > getCrsGraph() const; 00294 00296 00298 global_size_t getGlobalNumRows() const; 00299 00301 00303 global_size_t getGlobalNumCols() const; 00304 00306 size_t getNodeNumRows() const; 00307 00309 00311 size_t getNodeNumCols() const; 00312 00314 GlobalOrdinal getIndexBase() const; 00315 00317 global_size_t getGlobalNumEntries() const; 00318 00320 size_t getNodeNumEntries() const; 00321 00323 00324 size_t getNumEntriesInGlobalRow(GlobalOrdinal globalRow) const; 00325 00327 00328 size_t getNumEntriesInLocalRow(LocalOrdinal localRow) const; 00329 00331 00333 global_size_t getGlobalNumDiags() const; 00334 00336 00338 size_t getNodeNumDiags() const; 00339 00341 00343 size_t getGlobalMaxNumRowEntries() const; 00344 00346 00348 size_t getNodeMaxNumRowEntries() const; 00349 00351 bool hasColMap() const; 00352 00354 00356 bool isLowerTriangular() const; 00357 00359 00361 bool isUpperTriangular() const; 00362 00364 bool isLocallyIndexed() const; 00365 00367 bool isGloballyIndexed() const; 00368 00370 bool isFillComplete() const; 00371 00373 bool isFillActive() const; 00374 00376 00382 bool isStorageOptimized() const; 00383 00385 ProfileType getProfileType() const; 00386 00388 bool isStaticGraph() const; 00389 00391 00401 void getGlobalRowCopy(GlobalOrdinal GlobalRow, 00402 const ArrayView<GlobalOrdinal> &Indices, 00403 const ArrayView<Scalar> &Values, 00404 size_t &NumEntries 00405 ) const; 00406 00408 00420 void getLocalRowCopy(LocalOrdinal LocalRow, 00421 const ArrayView<LocalOrdinal> &Indices, 00422 const ArrayView<Scalar> &Values, 00423 size_t &NumEntries 00424 ) const; 00425 00427 00436 void getGlobalRowView(GlobalOrdinal GlobalRow, ArrayView<const GlobalOrdinal> &indices, ArrayView<const Scalar> &values) const; 00437 00439 00448 void getLocalRowView(LocalOrdinal LocalRow, ArrayView<const LocalOrdinal> &indices, ArrayView<const Scalar> &values) const; 00449 00451 00453 void getLocalDiagCopy(Vector<Scalar,LocalOrdinal,GlobalOrdinal,Node> &diag) const; 00454 00456 00458 00459 00461 00470 template <class DomainScalar, class RangeScalar> 00471 void multiply(const MultiVector<DomainScalar,LocalOrdinal,GlobalOrdinal,Node> & X, MultiVector<RangeScalar,LocalOrdinal,GlobalOrdinal,Node> &Y, Teuchos::ETransp trans, RangeScalar alpha, RangeScalar beta) const; 00472 00474 00480 template <class DomainScalar, class RangeScalar> 00481 void solve(const MultiVector<RangeScalar,LocalOrdinal,GlobalOrdinal,Node> & Y, MultiVector<DomainScalar,LocalOrdinal,GlobalOrdinal,Node> &X, Teuchos::ETransp trans) const; 00482 00484 00486 00487 00489 00492 void apply(const MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> & X, MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> &Y, 00493 Teuchos::ETransp mode = Teuchos::NO_TRANS, 00494 Scalar alpha = ScalarTraits<Scalar>::one(), 00495 Scalar beta = ScalarTraits<Scalar>::zero()) const; 00496 00498 bool hasTransposeApply() const; 00499 00502 const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > & getDomainMap() const; 00503 00506 const RCP<const Map<LocalOrdinal,GlobalOrdinal,Node> > & getRangeMap() const; 00507 00509 00511 00512 00514 std::string description() const; 00515 00517 void describe(Teuchos::FancyOStream &out, const Teuchos::EVerbosityLevel verbLevel=Teuchos::Describable::verbLevel_default) const; 00518 00520 00522 00523 00524 bool checkSizes(const DistObject<char, LocalOrdinal,GlobalOrdinal,Node>& source); 00525 00526 void copyAndPermute(const DistObject<char, LocalOrdinal,GlobalOrdinal,Node>& source, 00527 size_t numSameIDs, 00528 const ArrayView<const LocalOrdinal> &permuteToLIDs, 00529 const ArrayView<const LocalOrdinal> &permuteFromLIDs); 00530 00531 void packAndPrepare(const DistObject<char, LocalOrdinal,GlobalOrdinal,Node>& source, 00532 const ArrayView<const LocalOrdinal> &exportLIDs, 00533 Array<char> &exports, 00534 const ArrayView<size_t> & numPacketsPerLID, 00535 size_t& constantNumPackets, 00536 Distributor &distor); 00537 00538 void unpackAndCombine(const ArrayView<const LocalOrdinal> &importLIDs, 00539 const ArrayView<const char> &imports, 00540 const ArrayView<size_t> &numPacketsPerLID, 00541 size_t constantNumPackets, 00542 Distributor &distor, 00543 CombineMode CM); 00544 00546 00548 00549 00558 TPETRA_DEPRECATED void optimizeStorage(); 00559 00561 TPETRA_DEPRECATED void getGlobalRowView(GlobalOrdinal GlobalRow, ArrayRCP<const GlobalOrdinal> &indices, ArrayRCP<const Scalar> &values) const; 00562 00564 TPETRA_DEPRECATED void getLocalRowView(LocalOrdinal LocalRow, ArrayRCP<const LocalOrdinal> &indices, ArrayRCP<const Scalar> &values) const; 00565 00567 00568 private: 00569 // copy constructor disabled 00570 CrsMatrix(const CrsMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node,LocalMatOps> &Source); 00571 // operator= disabled 00572 CrsMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node,LocalMatOps> & operator=(const CrsMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node,LocalMatOps> &rhs); 00573 protected: 00574 // useful typedefs 00575 typedef OrdinalTraits<LocalOrdinal> LOT; 00576 typedef OrdinalTraits<GlobalOrdinal> GOT; 00577 typedef ScalarTraits<Scalar> ST; 00578 typedef MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> MV; 00579 typedef Vector<Scalar,LocalOrdinal,GlobalOrdinal,Node> V; 00580 typedef CrsGraph<LocalOrdinal,GlobalOrdinal,Node,LocalMatOps> Graph; 00581 // Enums 00582 enum GraphAllocationStatus { 00583 GraphAlreadyAllocated, 00584 GraphNotYetAllocated 00585 }; 00586 // Allocation 00587 void allocateValues(ELocalGlobal lg, GraphAllocationStatus gas); 00588 // Sorting and merging 00589 void sortEntries(); 00590 void mergeRedundantEntries(); 00591 // global consts 00592 void clearGlobalConstants(); 00593 void computeGlobalConstants(); 00594 // matrix data accessors 00595 ArrayView<const Scalar> getView(RowInfo rowinfo) const; 00596 ArrayView< Scalar> getViewNonConst(RowInfo rowinfo); 00597 // local Kokkos objects 00598 void pushToLocalMatrix(); 00599 void pullFromLocalMatrix(); 00600 void fillLocalMatrix(OptimizeOption os); 00601 void fillLocalSparseOps(); 00602 // debugging 00603 void checkInternalState() const; 00604 00605 // Two graph pointers needed in order to maintain const-correctness: 00606 // staticGraph_ is a graph passed to the constructor. We are not allowed to modify it. it is always a valid pointer. 00607 // myGraph_ is a graph created here. We are allowed to modify it. if myGraph_ != null, then staticGraph_ = myGraph_ 00608 RCP<const Graph> staticGraph_; 00609 RCP< Graph> myGraph_; 00610 00611 Kokkos::CrsMatrix<Scalar,LocalOrdinal,Node,LocalMatOps> lclMatrix_; 00612 typename LocalMatOps::template rebind<Scalar>::other lclMatOps_; 00613 00614 // matrix values. before allocation, both are null. 00615 // after allocation, one is null. 00616 // 1D == StaticAllocation, 2D == DynamicAllocation 00617 // The allocation always matches that of graph_, as the graph does the allocation for the matrix. 00618 ArrayRCP<Scalar> values1D_; 00619 ArrayRCP<ArrayRCP<Scalar> > values2D_; 00620 // TODO: these could be allocated at resumeFill() and de-allocated at fillComplete() to make for very fast getView()/getViewNonConst() 00621 // ArrayRCP< typedef ArrayRCP<const Scalar>::iterator > rowPtrs_; 00622 // ArrayRCP< typedef ArrayRCP< Scalar>::iterator > rowPtrsNC_; 00623 00624 bool fillComplete_; 00625 00626 // non-local data 00627 std::map<GlobalOrdinal, Array<std::pair<GlobalOrdinal,Scalar> > > nonlocals_; 00628 00629 // a wrapper around multiply, for use in apply; it contains a non-owning RCP to *this, therefore, it is not allowed 00630 // to persist past the destruction of *this. therefore, WE MAY NOT SHARE THIS POINTER. 00631 RCP< const CrsMatrixMultiplyOp<Scalar,Scalar,LocalOrdinal,GlobalOrdinal,Node,LocalMatOps> > sameScalarMultiplyOp_; 00632 00633 }; // class CrsMatrix 00634 00635 } // namespace Tpetra 00636 00637 #endif
1.7.4