|
Anasazi Version of the Day
|
00001 // @HEADER 00002 // *********************************************************************** 00003 // 00004 // Anasazi: Block Eigensolvers Package 00005 // Copyright (2010) Sandia Corporation 00006 // 00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive 00008 // license for use of this work by or on behalf of the U.S. Government. 00009 // 00010 // This library is free software; you can redistribute it and/or modify 00011 // it under the terms of the GNU Lesser General Public License as 00012 // published by the Free Software Foundation; either version 2.1 of the 00013 // License, or (at your option) any later version. 00014 // 00015 // This library is distributed in the hope that it will be useful, but 00016 // WITHOUT ANY WARRANTY; without even the implied warranty of 00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00018 // Lesser General Public License for more details. 00019 // 00020 // You should have received a copy of the GNU Lesser General Public 00021 // License along with this library; if not, write to the Free Software 00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 00023 // USA 00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 00025 // 00026 // *********************************************************************** 00027 // @HEADER 00028 #ifndef __TSQR_CombineNative_hpp 00029 #define __TSQR_CombineNative_hpp 00030 00031 #include "Tsqr_ApplyType.hpp" 00032 #include "Tsqr_ScalarTraits.hpp" 00033 #include "Tsqr_Blas.hpp" 00034 #include "Tsqr_Lapack.hpp" 00035 #include "Tsqr_CombineDefault.hpp" 00036 00037 // #define TSQR_COMBINE_NATIVE_DEBUG 1 00038 #ifdef TSQR_COMBINE_NATIVE_DEBUG 00039 #include "Tsqr_Util.hpp" 00040 #include <iostream> 00041 using std::cerr; 00042 using std::endl; 00043 #endif // TSQR_COMBINE_NATIVE_DEBUG 00044 00047 00048 namespace TSQR { 00049 00064 template< class Ordinal, class Scalar, bool isComplex = ScalarTraits< Scalar >::is_complex > 00065 class CombineNative 00066 { 00067 public: 00068 typedef Scalar scalar_type; 00069 typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type; 00070 typedef Ordinal ordinal_type; 00071 00072 private: 00073 typedef BLAS< ordinal_type, scalar_type > blas_type; 00074 typedef LAPACK< ordinal_type, scalar_type > lapack_type; 00075 typedef CombineDefault< ordinal_type, scalar_type > combine_default_type; 00076 00077 public: 00078 00079 CombineNative () {} 00080 00088 static bool QR_produces_R_factor_with_nonnegative_diagonal() { 00089 return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() && 00090 combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal(); 00091 } 00092 00093 void 00094 factor_first (const Ordinal nrows, 00095 const Ordinal ncols, 00096 Scalar A[], 00097 const Ordinal lda, 00098 Scalar tau[], 00099 Scalar work[]) const 00100 { 00101 return default_.factor_first (nrows, ncols, A, lda, tau, work); 00102 } 00103 00104 void 00105 apply_first (const ApplyType& applyType, 00106 const Ordinal nrows, 00107 const Ordinal ncols_C, 00108 const Ordinal ncols_A, 00109 const Scalar A[], 00110 const Ordinal lda, 00111 const Scalar tau[], 00112 Scalar C[], 00113 const Ordinal ldc, 00114 Scalar work[]) const 00115 { 00116 return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 00117 A, lda, tau, C, ldc, work); 00118 } 00119 00120 void 00121 apply_inner (const ApplyType& applyType, 00122 const Ordinal m, 00123 const Ordinal ncols_C, 00124 const Ordinal ncols_Q, 00125 const Scalar A[], 00126 const Ordinal lda, 00127 const Scalar tau[], 00128 Scalar C_top[], 00129 const Ordinal ldc_top, 00130 Scalar C_bot[], 00131 const Ordinal ldc_bot, 00132 Scalar work[]) const; 00133 00134 void 00135 factor_inner (const Ordinal m, 00136 const Ordinal n, 00137 Scalar R[], 00138 const Ordinal ldr, 00139 Scalar A[], 00140 const Ordinal lda, 00141 Scalar tau[], 00142 Scalar work[]) const; 00143 00144 void 00145 factor_pair (const Ordinal n, 00146 Scalar R_top[], 00147 const Ordinal ldr_top, 00148 Scalar R_bot[], 00149 const Ordinal ldr_bot, 00150 Scalar tau[], 00151 Scalar work[]) const; 00152 00153 void 00154 apply_pair (const ApplyType& applyType, 00155 const Ordinal ncols_C, 00156 const Ordinal ncols_Q, 00157 const Scalar R_bot[], 00158 const Ordinal ldr_bot, 00159 const Scalar tau[], 00160 Scalar C_top[], 00161 const Ordinal ldc_top, 00162 Scalar C_bot[], 00163 const Ordinal ldc_bot, 00164 Scalar work[]) const; 00165 00166 private: 00167 mutable combine_default_type default_; 00168 }; 00169 00170 00173 template< class Ordinal, class Scalar > 00174 class CombineNative< Ordinal, Scalar, false > 00175 { 00176 public: 00177 typedef Scalar scalar_type; 00178 typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type; 00179 typedef Ordinal ordinal_type; 00180 00181 private: 00182 typedef BLAS< ordinal_type, scalar_type > blas_type; 00183 typedef LAPACK< ordinal_type, scalar_type > lapack_type; 00184 typedef CombineDefault< ordinal_type, scalar_type > combine_default_type; 00185 00186 public: 00187 CombineNative () {} 00188 00189 static bool QR_produces_R_factor_with_nonnegative_diagonal() { 00190 return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() && 00191 combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal(); 00192 } 00193 00194 void 00195 factor_first (const Ordinal nrows, 00196 const Ordinal ncols, 00197 Scalar A[], 00198 const Ordinal lda, 00199 Scalar tau[], 00200 Scalar work[]) const 00201 { 00202 return default_.factor_first (nrows, ncols, A, lda, tau, work); 00203 } 00204 00205 void 00206 apply_first (const ApplyType& applyType, 00207 const Ordinal nrows, 00208 const Ordinal ncols_C, 00209 const Ordinal ncols_A, 00210 const Scalar A[], 00211 const Ordinal lda, 00212 const Scalar tau[], 00213 Scalar C[], 00214 const Ordinal ldc, 00215 Scalar work[]) const 00216 { 00217 return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 00218 A, lda, tau, C, ldc, work); 00219 } 00220 00221 void 00222 apply_inner (const ApplyType& applyType, 00223 const Ordinal m, 00224 const Ordinal ncols_C, 00225 const Ordinal ncols_Q, 00226 const Scalar A[], 00227 const Ordinal lda, 00228 const Scalar tau[], 00229 Scalar C_top[], 00230 const Ordinal ldc_top, 00231 Scalar C_bot[], 00232 const Ordinal ldc_bot, 00233 Scalar work[]) const; 00234 00235 void 00236 factor_inner (const Ordinal m, 00237 const Ordinal n, 00238 Scalar R[], 00239 const Ordinal ldr, 00240 Scalar A[], 00241 const Ordinal lda, 00242 Scalar tau[], 00243 Scalar work[]) const; 00244 00245 void 00246 factor_pair (const Ordinal n, 00247 Scalar R_top[], 00248 const Ordinal ldr_top, 00249 Scalar R_bot[], 00250 const Ordinal ldr_bot, 00251 Scalar tau[], 00252 Scalar work[]) const; 00253 00254 void 00255 apply_pair (const ApplyType& applyType, 00256 const Ordinal ncols_C, 00257 const Ordinal ncols_Q, 00258 const Scalar R_bot[], 00259 const Ordinal ldr_bot, 00260 const Scalar tau[], 00261 Scalar C_top[], 00262 const Ordinal ldc_top, 00263 Scalar C_bot[], 00264 const Ordinal ldc_bot, 00265 Scalar work[]) const; 00266 00267 private: 00268 mutable combine_default_type default_; 00269 }; 00270 00271 00274 template< class Ordinal, class Scalar > 00275 class CombineNative< Ordinal, Scalar, true > 00276 { 00277 public: 00278 typedef Scalar scalar_type; 00279 typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type; 00280 typedef Ordinal ordinal_type; 00281 00282 private: 00283 typedef BLAS< ordinal_type, scalar_type > blas_type; 00284 typedef LAPACK< ordinal_type, scalar_type > lapack_type; 00285 typedef CombineDefault< ordinal_type, scalar_type > combine_default_type; 00286 00287 public: 00288 CombineNative () {} 00289 00290 static bool QR_produces_R_factor_with_nonnegative_diagonal() { 00291 return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() && 00292 combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal(); 00293 } 00294 00295 void 00296 factor_first (const Ordinal nrows, 00297 const Ordinal ncols, 00298 Scalar A[], 00299 const Ordinal lda, 00300 Scalar tau[], 00301 Scalar work[]) const 00302 { 00303 return default_.factor_first (nrows, ncols, A, lda, tau, work); 00304 } 00305 00306 void 00307 apply_first (const ApplyType& applyType, 00308 const Ordinal nrows, 00309 const Ordinal ncols_C, 00310 const Ordinal ncols_A, 00311 const Scalar A[], 00312 const Ordinal lda, 00313 const Scalar tau[], 00314 Scalar C[], 00315 const Ordinal ldc, 00316 Scalar work[]) const 00317 { 00318 return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 00319 A, lda, tau, C, ldc, work); 00320 } 00321 00322 void 00323 apply_inner (const ApplyType& applyType, 00324 const Ordinal m, 00325 const Ordinal ncols_C, 00326 const Ordinal ncols_Q, 00327 const Scalar A[], 00328 const Ordinal lda, 00329 const Scalar tau[], 00330 Scalar C_top[], 00331 const Ordinal ldc_top, 00332 Scalar C_bot[], 00333 const Ordinal ldc_bot, 00334 Scalar work[]) const 00335 { 00336 return default_.apply_inner (applyType, m, ncols_C, ncols_Q, 00337 A, lda, tau, 00338 C_top, ldc_top, C_bot, ldc_bot, 00339 work); 00340 } 00341 00342 void 00343 factor_inner (const Ordinal m, 00344 const Ordinal n, 00345 Scalar R[], 00346 const Ordinal ldr, 00347 Scalar A[], 00348 const Ordinal lda, 00349 Scalar tau[], 00350 Scalar work[]) const 00351 { 00352 return default_.factor_inner (m, n, R, ldr, A, lda, tau, work); 00353 } 00354 00355 void 00356 factor_pair (const Ordinal n, 00357 Scalar R_top[], 00358 const Ordinal ldr_top, 00359 Scalar R_bot[], 00360 const Ordinal ldr_bot, 00361 Scalar tau[], 00362 Scalar work[]) const 00363 { 00364 return default_.factor_pair (n, R_top, ldr_top, R_bot, ldr_bot, tau, work); 00365 } 00366 00367 void 00368 apply_pair (const ApplyType& applyType, 00369 const Ordinal ncols_C, 00370 const Ordinal ncols_Q, 00371 const Scalar R_bot[], 00372 const Ordinal ldr_bot, 00373 const Scalar tau[], 00374 Scalar C_top[], 00375 const Ordinal ldc_top, 00376 Scalar C_bot[], 00377 const Ordinal ldc_bot, 00378 Scalar work[]) const 00379 { 00380 return default_.apply_pair (applyType, ncols_C, ncols_Q, 00381 R_bot, ldr_bot, tau, 00382 C_top, ldc_top, C_bot, ldc_bot, 00383 work); 00384 } 00385 00386 private: 00387 mutable combine_default_type default_; 00388 }; 00389 00390 00391 template< class Ordinal, class Scalar > 00392 void 00393 CombineNative< Ordinal, Scalar, false >:: 00394 factor_inner (const Ordinal m, 00395 const Ordinal n, 00396 Scalar R[], 00397 const Ordinal ldr, 00398 Scalar A[], 00399 const Ordinal lda, 00400 Scalar tau[], 00401 Scalar work[]) const 00402 { 00403 const Scalar ZERO(0), ONE(1); 00404 lapack_type lapack; 00405 blas_type blas; 00406 00407 for (Ordinal k = 0; k < n; ++k) 00408 work[k] = ZERO; 00409 00410 for (Ordinal k = 0; k < n-1; ++k) 00411 { 00412 Scalar& R_kk = R[ k + k * ldr ]; 00413 Scalar* const A_1k = &A[ 0 + k * lda ]; 00414 Scalar* const A_1kp1 = &A[ 0 + (k+1) * lda ]; 00415 00416 lapack.LARFP (m + 1, R_kk, A_1k, 1, tau[k]); 00417 blas.GEMV ("T", m, n-k-1, ONE, A_1kp1, lda, A_1k, 1, ZERO, work, 1); 00418 00419 for (Ordinal j = k+1; j < n; ++j) 00420 { 00421 Scalar& R_kj = R[ k + j*ldr ]; 00422 00423 work[j-k-1] += R_kj; 00424 R_kj -= tau[k] * work[j-k-1]; 00425 } 00426 blas.GER (m, n-k-1, -tau[k], A_1k, 1, work, 1, A_1kp1, lda); 00427 } 00428 Scalar& R_nn = R[ (n-1) + (n-1) * ldr ]; 00429 Scalar* const A_1n = &A[ 0 + (n-1) * lda ]; 00430 00431 lapack.LARFP (m+1, R_nn, A_1n, 1, tau[n-1]); 00432 } 00433 00434 00435 template< class Ordinal, class Scalar > 00436 void 00437 CombineNative< Ordinal, Scalar, false >:: 00438 apply_inner (const ApplyType& applyType, 00439 const Ordinal m, 00440 const Ordinal ncols_C, 00441 const Ordinal ncols_Q, 00442 const Scalar A[], 00443 const Ordinal lda, 00444 const Scalar tau[], 00445 Scalar C_top[], 00446 const Ordinal ldc_top, 00447 Scalar C_bot[], 00448 const Ordinal ldc_bot, 00449 Scalar work[]) const 00450 { 00451 const Scalar ZERO(0); 00452 blas_type blas; 00453 00454 //Scalar* const y = work; 00455 for (Ordinal i = 0; i < ncols_C; ++i) 00456 work[i] = ZERO; 00457 00458 Ordinal j_start, j_end, j_step; 00459 if (applyType == ApplyType::NoTranspose) 00460 { 00461 j_start = ncols_Q - 1; 00462 j_end = -1; // exclusive 00463 j_step = -1; 00464 } 00465 else 00466 { 00467 j_start = 0; 00468 j_end = ncols_Q; // exclusive 00469 j_step = +1; 00470 } 00471 for (Ordinal j = j_start; j != j_end; j += j_step) 00472 { 00473 const Scalar* const A_1j = &A[ 0 + j*lda ]; 00474 00475 //blas.GEMV ("T", m, ncols_C, ONE, C_bot, ldc_bot, A_1j, 1, ZERO, &y[0], 1); 00476 for (Ordinal i = 0; i < ncols_C; ++i) 00477 { 00478 work[i] = ZERO; 00479 for (Ordinal k = 0; k < m; ++k) 00480 work[i] += A_1j[k] * C_bot[ k + i*ldc_bot ]; 00481 00482 work[i] += C_top[ j + i*ldc_top ]; 00483 } 00484 for (Ordinal k = 0; k < ncols_C; ++k) 00485 C_top[ j + k*ldc_top ] -= tau[j] * work[k]; 00486 00487 blas.GER (m, ncols_C, -tau[j], A_1j, 1, work, 1, C_bot, ldc_bot); 00488 } 00489 } 00490 00491 00492 template< class Ordinal, class Scalar > 00493 void 00494 CombineNative< Ordinal, Scalar, false >:: 00495 factor_pair (const Ordinal n, 00496 Scalar R_top[], 00497 const Ordinal ldr_top, 00498 Scalar R_bot[], 00499 const Ordinal ldr_bot, 00500 Scalar tau[], 00501 Scalar work[]) const 00502 { 00503 const Scalar ZERO(0), ONE(1); 00504 lapack_type lapack; 00505 blas_type blas; 00506 00507 for (Ordinal k = 0; k < n; ++k) 00508 work[k] = ZERO; 00509 00510 for (Ordinal k = 0; k < n-1; ++k) 00511 { 00512 Scalar& R_top_kk = R_top[ k + k * ldr_top ]; 00513 Scalar* const R_bot_1k = &R_bot[ 0 + k * ldr_bot ]; 00514 Scalar* const R_bot_1kp1 = &R_bot[ 0 + (k+1) * ldr_bot ]; 00515 00516 // k+2: 1 element in R_top (R_top(k,k)), and k+1 elements in 00517 // R_bot (R_bot(1:k,k), in 1-based indexing notation). 00518 lapack.LARFP (k+2, R_top_kk, R_bot_1k, 1, tau[k]); 00519 // One-based indexing, Matlab version of the GEMV call below: 00520 // work(1:k) := R_bot(1:k,k+1:n)' * R_bot(1:k,k) 00521 blas.GEMV ("T", k+1, n-k-1, ONE, R_bot_1kp1, ldr_bot, R_bot_1k, 1, ZERO, work, 1); 00522 00523 for (Ordinal j = k+1; j < n; ++j) 00524 { 00525 Scalar& R_top_kj = R_top[ k + j*ldr_top ]; 00526 work[j-k-1] += R_top_kj; 00527 R_top_kj -= tau[k] * work[j-k-1]; 00528 } 00529 blas.GER (k+1, n-k-1, -tau[k], R_bot_1k, 1, work, 1, R_bot_1kp1, ldr_bot); 00530 } 00531 Scalar& R_top_nn = R_top[ (n-1) + (n-1)*ldr_top ]; 00532 Scalar* const R_bot_1n = &R_bot[ 0 + (n-1)*ldr_bot ]; 00533 00534 // n+1: 1 element in R_top (n,n), and n elements in R_bot (the 00535 // whole last column). 00536 lapack.LARFP (n+1, R_top_nn, R_bot_1n, 1, tau[n-1]); 00537 } 00538 00539 00540 template< class Ordinal, class Scalar > 00541 void 00542 CombineNative< Ordinal, Scalar, false >:: 00543 apply_pair (const ApplyType& applyType, 00544 const Ordinal ncols_C, 00545 const Ordinal ncols_Q, 00546 const Scalar R_bot[], 00547 const Ordinal ldr_bot, 00548 const Scalar tau[], 00549 Scalar C_top[], 00550 const Ordinal ldc_top, 00551 Scalar C_bot[], 00552 const Ordinal ldc_bot, 00553 Scalar work[]) const 00554 { 00555 const Scalar ZERO(0); 00556 blas_type blas; 00557 00558 for (Ordinal i = 0; i < ncols_C; ++i) 00559 work[i] = ZERO; 00560 00561 Ordinal j_start, j_end, j_step; 00562 if (applyType == ApplyType::NoTranspose) 00563 { 00564 j_start = ncols_Q - 1; 00565 j_end = -1; // exclusive 00566 j_step = -1; 00567 } 00568 else 00569 { 00570 j_start = 0; 00571 j_end = ncols_Q; // exclusive 00572 j_step = +1; 00573 } 00574 for (Ordinal j_Q = j_start; j_Q != j_end; j_Q += j_step) 00575 { // Using Householder reflector stored in column j_Q of R_bot 00576 const Scalar* const R_bot_col = &R_bot[ 0 + j_Q*ldr_bot ]; 00577 00578 // In 1-based indexing notation, with k in 1, 2, ..., ncols_C 00579 // (inclusive): (Output is length ncols_C row vector) 00580 // 00581 // work(1:j) := R_bot(1:j,j)' * C_bot(1:j, 1:ncols_C) - C_top(j, 1:ncols_C) 00582 for (Ordinal j_C = 0; j_C < ncols_C; ++j_C) 00583 { // For each column j_C of [C_top; C_bot], update row j_Q 00584 // of C_top and rows 1:j_Q of C_bot. (Again, this is in 00585 // 1-based indexing notation. 00586 00587 Scalar work_j_C = ZERO; 00588 const Scalar* const C_bot_col = &C_bot[ 0 + j_C*ldc_bot ]; 00589 00590 for (Ordinal k = 0; k <= j_Q; ++k) 00591 work_j_C += R_bot_col[k] * C_bot_col[k]; 00592 00593 work_j_C += C_top[ j_Q + j_C*ldc_top ]; 00594 work[j_C] = work_j_C; 00595 } 00596 for (Ordinal j_C = 0; j_C < ncols_C; ++j_C) 00597 C_top[ j_Q + j_C*ldc_top ] -= tau[j_Q] * work[j_C]; 00598 00599 blas.GER (j_Q+1, ncols_C, -tau[j_Q], R_bot_col, 1, work, 1, C_bot, ldc_bot); 00600 } 00601 } 00602 } // namespace TSQR 00603 00604 00605 00606 #endif // __TSQR_CombineNative_hpp
1.7.4