|
EpetraExt Development
|
00001 //@HEADER 00002 // *********************************************************************** 00003 // 00004 // New_Package Example Package 00005 // Copyright (2004) Sandia Corporation 00006 // 00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive 00008 // license for use of this work by or on behalf of the U.S. Government. 00009 // 00010 // This library is free software; you can redistribute it and/or modify 00011 // it under the terms of the GNU Lesser General Public License as 00012 // published by the Free Software Foundation; either version 2.1 of the 00013 // License, or (at your option) any later version. 00014 // 00015 // This library is distributed in the hope that it will be useful, but 00016 // WITHOUT ANY WARRANTY; without even the implied warranty of 00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00018 // Lesser General Public License for more details. 00019 // 00020 // You should have received a copy of the GNU Lesser General Public 00021 // License along with this library; if not, write to the Free Software 00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 00023 // USA 00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 00025 // 00026 // *********************************************************************** 00027 //@HEADER 00028 00029 #include "EpetraExt_PutRowMatrix.h" 00030 #include "Epetra_Comm.h" 00031 #include "Epetra_Map.h" 00032 #include "Epetra_Vector.h" 00033 #include "Epetra_IntVector.h" 00034 #include "Epetra_SerialDenseVector.h" 00035 #include "Epetra_IntSerialDenseVector.h" 00036 #include "Epetra_Import.h" 00037 #include "Epetra_RowMatrix.h" 00038 #include "Epetra_CrsMatrix.h" 00039 00040 using namespace Matlab; 00041 namespace Matlab { 00042 00043 int CopyRowMatrix(mxArray* matlabA, const Epetra_RowMatrix& A) { 00044 int valueCount = 0; 00045 //int* valueCount = &temp; 00046 00047 Epetra_Map map = A.RowMatrixRowMap(); 00048 const Epetra_Comm & comm = map.Comm(); 00049 int numProc = comm.NumProc(); 00050 00051 if (numProc==1) 00052 DoCopyRowMatrix(matlabA, valueCount, A); 00053 else { 00054 int numRows = map.NumMyElements(); 00055 00056 //cout << "creating allGidsMap\n"; 00057 Epetra_Map allGidsMap(-1, numRows, 0,comm); 00058 //cout << "done creating allGidsMap\n"; 00059 00060 Epetra_IntVector allGids(allGidsMap); 00061 for (int i=0; i<numRows; i++) allGids[i] = map.GID(i); 00062 00063 // Now construct a RowMatrix on PE 0 by strip-mining the rows of the input matrix A. 00064 int numChunks = numProc; 00065 int stripSize = allGids.GlobalLength()/numChunks; 00066 int remainder = allGids.GlobalLength()%numChunks; 00067 int curStart = 0; 00068 int curStripSize = 0; 00069 Epetra_IntSerialDenseVector importGidList; 00070 int numImportGids = 0; 00071 if (comm.MyPID()==0) 00072 importGidList.Size(stripSize+1); // Set size of vector to max needed 00073 for (int i=0; i<numChunks; i++) { 00074 if (comm.MyPID()==0) { // Only PE 0 does this part 00075 curStripSize = stripSize; 00076 if (i<remainder) curStripSize++; // handle leftovers 00077 for (int j=0; j<curStripSize; j++) importGidList[j] = j + curStart; 00078 curStart += curStripSize; 00079 } 00080 // The following import map will be non-trivial only on PE 0. 00081 //cout << "creating importGidMap\n"; 00082 Epetra_Map importGidMap(-1, curStripSize, importGidList.Values(), 0, comm); 00083 //cout << "done creating importGidMap\n"; 00084 Epetra_Import gidImporter(importGidMap, allGidsMap); 00085 Epetra_IntVector importGids(importGidMap); 00086 if (importGids.Import(allGids, gidImporter, Insert)) return(-1); 00087 00088 // importGids now has a list of GIDs for the current strip of matrix rows. 00089 // Use these values to build another importer that will get rows of the matrix. 00090 00091 // The following import map will be non-trivial only on PE 0. 00092 //cout << "creating importMap\n"; 00093 //cout << "A.RowMatrixRowMap().MinAllGID: " << A.RowMatrixRowMap().MinAllGID() << "\n"; 00094 Epetra_Map importMap(-1, importGids.MyLength(), importGids.Values(), A.RowMatrixRowMap().MinAllGID(), comm); 00095 //cout << "done creating importMap\n"; 00096 Epetra_Import importer(importMap, map); 00097 Epetra_CrsMatrix importA(Copy, importMap, 0); 00098 if (importA.Import(A, importer, Insert)) return(-1); 00099 if (importA.FillComplete()) return(-1); 00100 00101 // Finally we are ready to write this strip of the matrix to ostream 00102 if (DoCopyRowMatrix(matlabA, valueCount, importA)) return(-1); 00103 } 00104 } 00105 00106 if (A.RowMatrixRowMap().Comm().MyPID() == 0) { 00107 // set max cap 00108 int* matlabAcolumnIndicesPtr = mxGetJc(matlabA); 00109 matlabAcolumnIndicesPtr[A.NumGlobalRows()] = valueCount; 00110 } 00111 00112 return(0); 00113 } 00114 00115 int DoCopyRowMatrix(mxArray* matlabA, int& valueCount, const Epetra_RowMatrix& A) { 00116 //cout << "doing DoCopyRowMatrix\n"; 00117 int ierr = 0; 00118 int numRows = A.NumGlobalRows(); 00119 //cout << "numRows: " << numRows << "\n"; 00120 Epetra_Map rowMap = A.RowMatrixRowMap(); 00121 Epetra_Map colMap = A.RowMatrixColMap(); 00122 int minAllGID = rowMap.MinAllGID(); 00123 00124 const Epetra_Comm & comm = rowMap.Comm(); 00125 //cout << "did global setup\n"; 00126 if (comm.MyPID()!=0) { 00127 if (A.NumMyRows()!=0) ierr = -1; 00128 if (A.NumMyCols()!=0) ierr = -1; 00129 } 00130 else { 00131 // declare and get initial values of all matlabA pointers 00132 double* matlabAvaluesPtr = mxGetPr(matlabA); 00133 int* matlabAcolumnIndicesPtr = mxGetJc(matlabA); 00134 int* matlabArowIndicesPtr = mxGetIr(matlabA); 00135 00136 // set all matlabA pointers to the proper offset 00137 matlabAvaluesPtr += valueCount; 00138 matlabArowIndicesPtr += valueCount; 00139 00140 if (numRows!=A.NumMyRows()) ierr = -1; 00141 Epetra_SerialDenseVector values(A.MaxNumEntries()); 00142 Epetra_IntSerialDenseVector indices(A.MaxNumEntries()); 00143 //cout << "did proc0 setup\n"; 00144 for (int i=0; i<numRows; i++) { 00145 //cout << "extracting a row\n"; 00146 int I = rowMap.GID(i); 00147 int numEntries = 0; 00148 if (A.ExtractMyRowCopy(i, values.Length(), numEntries, 00149 values.Values(), indices.Values())) return(-1); 00150 matlabAcolumnIndicesPtr[I - minAllGID] = valueCount; // set the starting index of column I 00151 double* serialValuesPtr = values.Values(); 00152 for (int j=0; j<numEntries; j++) { 00153 int J = colMap.GID(indices[j]); 00154 *matlabAvaluesPtr = *serialValuesPtr++; 00155 *matlabArowIndicesPtr = J; 00156 // increment matlabA pointers 00157 matlabAvaluesPtr++; 00158 matlabArowIndicesPtr++; 00159 valueCount++; 00160 } 00161 } 00162 //cout << "proc0 row extraction for this chunck is done\n"; 00163 } 00164 00165 /* 00166 if (comm.MyPID() == 0) { 00167 cout << "printing matlabA pointers\n"; 00168 double* matlabAvaluesPtr = mxGetPr(matlabA); 00169 int* matlabAcolumnIndicesPtr = mxGetJc(matlabA); 00170 int* matlabArowIndicesPtr = mxGetIr(matlabA); 00171 for(int i=0; i < numRows; i++) { 00172 for(int j=0; j < A.MaxNumEntries(); j++) { 00173 cout << "*matlabAvaluesPtr: " << *matlabAvaluesPtr++ << " *matlabAcolumnIndicesPtr: " << *matlabAcolumnIndicesPtr++ << " *matlabArowIndicesPtr" << *matlabArowIndicesPtr++ << "\n"; 00174 } 00175 } 00176 00177 cout << "done printing matlabA pointers\n"; 00178 } 00179 */ 00180 00181 int ierrGlobal; 00182 comm.MinAll(&ierr, &ierrGlobal, 1); // If any processor has -1, all return -1 00183 return(ierrGlobal); 00184 } 00185 00186 } // namespace Matlab
1.7.4