$treeview $search $mathjax
Eigen
3.2.5
$projectbrief
|
$projectbrief
|
$searchbox |
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_PASTIXSUPPORT_H 00011 #define EIGEN_PASTIXSUPPORT_H 00012 00013 namespace Eigen { 00014 00023 template<typename _MatrixType, bool IsStrSym = false> class PastixLU; 00024 template<typename _MatrixType, int Options> class PastixLLT; 00025 template<typename _MatrixType, int Options> class PastixLDLT; 00026 00027 namespace internal 00028 { 00029 00030 template<class Pastix> struct pastix_traits; 00031 00032 template<typename _MatrixType> 00033 struct pastix_traits< PastixLU<_MatrixType> > 00034 { 00035 typedef _MatrixType MatrixType; 00036 typedef typename _MatrixType::Scalar Scalar; 00037 typedef typename _MatrixType::RealScalar RealScalar; 00038 typedef typename _MatrixType::Index Index; 00039 }; 00040 00041 template<typename _MatrixType, int Options> 00042 struct pastix_traits< PastixLLT<_MatrixType,Options> > 00043 { 00044 typedef _MatrixType MatrixType; 00045 typedef typename _MatrixType::Scalar Scalar; 00046 typedef typename _MatrixType::RealScalar RealScalar; 00047 typedef typename _MatrixType::Index Index; 00048 }; 00049 00050 template<typename _MatrixType, int Options> 00051 struct pastix_traits< PastixLDLT<_MatrixType,Options> > 00052 { 00053 typedef _MatrixType MatrixType; 00054 typedef typename _MatrixType::Scalar Scalar; 00055 typedef typename _MatrixType::RealScalar RealScalar; 00056 typedef typename _MatrixType::Index Index; 00057 }; 00058 00059 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, float *vals, int *perm, int * invp, float *x, int nbrhs, int *iparm, double *dparm) 00060 { 00061 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; } 00062 if (nbrhs == 0) {x = NULL; nbrhs=1;} 00063 s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm); 00064 } 00065 00066 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, double *vals, int *perm, int * invp, double *x, int nbrhs, int *iparm, double *dparm) 00067 { 00068 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; } 00069 if (nbrhs == 0) {x = NULL; nbrhs=1;} 00070 d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm); 00071 } 00072 00073 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<float> *vals, int *perm, int * invp, std::complex<float> *x, int nbrhs, int *iparm, double *dparm) 00074 { 00075 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; } 00076 if (nbrhs == 0) {x = NULL; nbrhs=1;} 00077 c_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<COMPLEX*>(vals), perm, invp, reinterpret_cast<COMPLEX*>(x), nbrhs, iparm, dparm); 00078 } 00079 00080 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<double> *vals, int *perm, int * invp, std::complex<double> *x, int nbrhs, int *iparm, double *dparm) 00081 { 00082 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; } 00083 if (nbrhs == 0) {x = NULL; nbrhs=1;} 00084 z_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<DCOMPLEX*>(vals), perm, invp, reinterpret_cast<DCOMPLEX*>(x), nbrhs, iparm, dparm); 00085 } 00086 00087 // Convert the matrix to Fortran-style Numbering 00088 template <typename MatrixType> 00089 void c_to_fortran_numbering (MatrixType& mat) 00090 { 00091 if ( !(mat.outerIndexPtr()[0]) ) 00092 { 00093 int i; 00094 for(i = 0; i <= mat.rows(); ++i) 00095 ++mat.outerIndexPtr()[i]; 00096 for(i = 0; i < mat.nonZeros(); ++i) 00097 ++mat.innerIndexPtr()[i]; 00098 } 00099 } 00100 00101 // Convert to C-style Numbering 00102 template <typename MatrixType> 00103 void fortran_to_c_numbering (MatrixType& mat) 00104 { 00105 // Check the Numbering 00106 if ( mat.outerIndexPtr()[0] == 1 ) 00107 { // Convert to C-style numbering 00108 int i; 00109 for(i = 0; i <= mat.rows(); ++i) 00110 --mat.outerIndexPtr()[i]; 00111 for(i = 0; i < mat.nonZeros(); ++i) 00112 --mat.innerIndexPtr()[i]; 00113 } 00114 } 00115 } 00116 00117 // This is the base class to interface with PaStiX functions. 00118 // Users should not used this class directly. 00119 template <class Derived> 00120 class PastixBase : internal::noncopyable 00121 { 00122 public: 00123 typedef typename internal::pastix_traits<Derived>::MatrixType _MatrixType; 00124 typedef _MatrixType MatrixType; 00125 typedef typename MatrixType::Scalar Scalar; 00126 typedef typename MatrixType::RealScalar RealScalar; 00127 typedef typename MatrixType::Index Index; 00128 typedef Matrix<Scalar,Dynamic,1> Vector; 00129 typedef SparseMatrix<Scalar, ColMajor> ColSpMatrix; 00130 00131 public: 00132 00133 PastixBase() : m_initisOk(false), m_analysisIsOk(false), m_factorizationIsOk(false), m_isInitialized(false), m_pastixdata(0), m_size(0) 00134 { 00135 init(); 00136 } 00137 00138 ~PastixBase() 00139 { 00140 clean(); 00141 } 00142 00147 template<typename Rhs> 00148 inline const internal::solve_retval<PastixBase, Rhs> 00149 solve(const MatrixBase<Rhs>& b) const 00150 { 00151 eigen_assert(m_isInitialized && "Pastix solver is not initialized."); 00152 eigen_assert(rows()==b.rows() 00153 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b"); 00154 return internal::solve_retval<PastixBase, Rhs>(*this, b.derived()); 00155 } 00156 00157 template<typename Rhs,typename Dest> 00158 bool _solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const; 00159 00160 Derived& derived() 00161 { 00162 return *static_cast<Derived*>(this); 00163 } 00164 const Derived& derived() const 00165 { 00166 return *static_cast<const Derived*>(this); 00167 } 00168 00174 Array<Index,IPARM_SIZE,1>& iparm() 00175 { 00176 return m_iparm; 00177 } 00178 00183 int& iparm(int idxparam) 00184 { 00185 return m_iparm(idxparam); 00186 } 00187 00192 Array<RealScalar,IPARM_SIZE,1>& dparm() 00193 { 00194 return m_dparm; 00195 } 00196 00197 00201 double& dparm(int idxparam) 00202 { 00203 return m_dparm(idxparam); 00204 } 00205 00206 inline Index cols() const { return m_size; } 00207 inline Index rows() const { return m_size; } 00208 00217 ComputationInfo info() const 00218 { 00219 eigen_assert(m_isInitialized && "Decomposition is not initialized."); 00220 return m_info; 00221 } 00222 00227 template<typename Rhs> 00228 inline const internal::sparse_solve_retval<PastixBase, Rhs> 00229 solve(const SparseMatrixBase<Rhs>& b) const 00230 { 00231 eigen_assert(m_isInitialized && "Pastix LU, LLT or LDLT is not initialized."); 00232 eigen_assert(rows()==b.rows() 00233 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b"); 00234 return internal::sparse_solve_retval<PastixBase, Rhs>(*this, b.derived()); 00235 } 00236 00237 protected: 00238 00239 // Initialize the Pastix data structure, check the matrix 00240 void init(); 00241 00242 // Compute the ordering and the symbolic factorization 00243 void analyzePattern(ColSpMatrix& mat); 00244 00245 // Compute the numerical factorization 00246 void factorize(ColSpMatrix& mat); 00247 00248 // Free all the data allocated by Pastix 00249 void clean() 00250 { 00251 eigen_assert(m_initisOk && "The Pastix structure should be allocated first"); 00252 m_iparm(IPARM_START_TASK) = API_TASK_CLEAN; 00253 m_iparm(IPARM_END_TASK) = API_TASK_CLEAN; 00254 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0, 00255 m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data()); 00256 } 00257 00258 void compute(ColSpMatrix& mat); 00259 00260 int m_initisOk; 00261 int m_analysisIsOk; 00262 int m_factorizationIsOk; 00263 bool m_isInitialized; 00264 mutable ComputationInfo m_info; 00265 mutable pastix_data_t *m_pastixdata; // Data structure for pastix 00266 mutable int m_comm; // The MPI communicator identifier 00267 mutable Matrix<int,IPARM_SIZE,1> m_iparm; // integer vector for the input parameters 00268 mutable Matrix<double,DPARM_SIZE,1> m_dparm; // Scalar vector for the input parameters 00269 mutable Matrix<Index,Dynamic,1> m_perm; // Permutation vector 00270 mutable Matrix<Index,Dynamic,1> m_invp; // Inverse permutation vector 00271 mutable int m_size; // Size of the matrix 00272 }; 00273 00278 template <class Derived> 00279 void PastixBase<Derived>::init() 00280 { 00281 m_size = 0; 00282 m_iparm.setZero(IPARM_SIZE); 00283 m_dparm.setZero(DPARM_SIZE); 00284 00285 m_iparm(IPARM_MODIFY_PARAMETER) = API_NO; 00286 pastix(&m_pastixdata, MPI_COMM_WORLD, 00287 0, 0, 0, 0, 00288 0, 0, 0, 1, m_iparm.data(), m_dparm.data()); 00289 00290 m_iparm[IPARM_MATRIX_VERIFICATION] = API_NO; 00291 m_iparm[IPARM_VERBOSE] = 2; 00292 m_iparm[IPARM_ORDERING] = API_ORDER_SCOTCH; 00293 m_iparm[IPARM_INCOMPLETE] = API_NO; 00294 m_iparm[IPARM_OOC_LIMIT] = 2000; 00295 m_iparm[IPARM_RHS_MAKING] = API_RHS_B; 00296 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO; 00297 00298 m_iparm(IPARM_START_TASK) = API_TASK_INIT; 00299 m_iparm(IPARM_END_TASK) = API_TASK_INIT; 00300 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0, 00301 0, 0, 0, 0, m_iparm.data(), m_dparm.data()); 00302 00303 // Check the returned error 00304 if(m_iparm(IPARM_ERROR_NUMBER)) { 00305 m_info = InvalidInput; 00306 m_initisOk = false; 00307 } 00308 else { 00309 m_info = Success; 00310 m_initisOk = true; 00311 } 00312 } 00313 00314 template <class Derived> 00315 void PastixBase<Derived>::compute(ColSpMatrix& mat) 00316 { 00317 eigen_assert(mat.rows() == mat.cols() && "The input matrix should be squared"); 00318 00319 analyzePattern(mat); 00320 factorize(mat); 00321 00322 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO; 00323 m_isInitialized = m_factorizationIsOk; 00324 } 00325 00326 00327 template <class Derived> 00328 void PastixBase<Derived>::analyzePattern(ColSpMatrix& mat) 00329 { 00330 eigen_assert(m_initisOk && "The initialization of PaSTiX failed"); 00331 00332 // clean previous calls 00333 if(m_size>0) 00334 clean(); 00335 00336 m_size = mat.rows(); 00337 m_perm.resize(m_size); 00338 m_invp.resize(m_size); 00339 00340 m_iparm(IPARM_START_TASK) = API_TASK_ORDERING; 00341 m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE; 00342 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(), 00343 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data()); 00344 00345 // Check the returned error 00346 if(m_iparm(IPARM_ERROR_NUMBER)) 00347 { 00348 m_info = NumericalIssue; 00349 m_analysisIsOk = false; 00350 } 00351 else 00352 { 00353 m_info = Success; 00354 m_analysisIsOk = true; 00355 } 00356 } 00357 00358 template <class Derived> 00359 void PastixBase<Derived>::factorize(ColSpMatrix& mat) 00360 { 00361 // if(&m_cpyMat != &mat) m_cpyMat = mat; 00362 eigen_assert(m_analysisIsOk && "The analysis phase should be called before the factorization phase"); 00363 m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT; 00364 m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT; 00365 m_size = mat.rows(); 00366 00367 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(), 00368 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data()); 00369 00370 // Check the returned error 00371 if(m_iparm(IPARM_ERROR_NUMBER)) 00372 { 00373 m_info = NumericalIssue; 00374 m_factorizationIsOk = false; 00375 m_isInitialized = false; 00376 } 00377 else 00378 { 00379 m_info = Success; 00380 m_factorizationIsOk = true; 00381 m_isInitialized = true; 00382 } 00383 } 00384 00385 /* Solve the system */ 00386 template<typename Base> 00387 template<typename Rhs,typename Dest> 00388 bool PastixBase<Base>::_solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const 00389 { 00390 eigen_assert(m_isInitialized && "The matrix should be factorized first"); 00391 EIGEN_STATIC_ASSERT((Dest::Flags&RowMajorBit)==0, 00392 THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES); 00393 int rhs = 1; 00394 00395 x = b; /* on return, x is overwritten by the computed solution */ 00396 00397 for (int i = 0; i < b.cols(); i++){ 00398 m_iparm[IPARM_START_TASK] = API_TASK_SOLVE; 00399 m_iparm[IPARM_END_TASK] = API_TASK_REFINE; 00400 00401 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, x.rows(), 0, 0, 0, 00402 m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data()); 00403 } 00404 00405 // Check the returned error 00406 m_info = m_iparm(IPARM_ERROR_NUMBER)==0 ? Success : NumericalIssue; 00407 00408 return m_iparm(IPARM_ERROR_NUMBER)==0; 00409 } 00410 00430 template<typename _MatrixType, bool IsStrSym> 00431 class PastixLU : public PastixBase< PastixLU<_MatrixType> > 00432 { 00433 public: 00434 typedef _MatrixType MatrixType; 00435 typedef PastixBase<PastixLU<MatrixType> > Base; 00436 typedef typename Base::ColSpMatrix ColSpMatrix; 00437 typedef typename MatrixType::Index Index; 00438 00439 public: 00440 PastixLU() : Base() 00441 { 00442 init(); 00443 } 00444 00445 PastixLU(const MatrixType& matrix):Base() 00446 { 00447 init(); 00448 compute(matrix); 00449 } 00455 void compute (const MatrixType& matrix) 00456 { 00457 m_structureIsUptodate = false; 00458 ColSpMatrix temp; 00459 grabMatrix(matrix, temp); 00460 Base::compute(temp); 00461 } 00467 void analyzePattern(const MatrixType& matrix) 00468 { 00469 m_structureIsUptodate = false; 00470 ColSpMatrix temp; 00471 grabMatrix(matrix, temp); 00472 Base::analyzePattern(temp); 00473 } 00474 00480 void factorize(const MatrixType& matrix) 00481 { 00482 ColSpMatrix temp; 00483 grabMatrix(matrix, temp); 00484 Base::factorize(temp); 00485 } 00486 protected: 00487 00488 void init() 00489 { 00490 m_structureIsUptodate = false; 00491 m_iparm(IPARM_SYM) = API_SYM_NO; 00492 m_iparm(IPARM_FACTORIZATION) = API_FACT_LU; 00493 } 00494 00495 void grabMatrix(const MatrixType& matrix, ColSpMatrix& out) 00496 { 00497 if(IsStrSym) 00498 out = matrix; 00499 else 00500 { 00501 if(!m_structureIsUptodate) 00502 { 00503 // update the transposed structure 00504 m_transposedStructure = matrix.transpose(); 00505 00506 // Set the elements of the matrix to zero 00507 for (Index j=0; j<m_transposedStructure.outerSize(); ++j) 00508 for(typename ColSpMatrix::InnerIterator it(m_transposedStructure, j); it; ++it) 00509 it.valueRef() = 0.0; 00510 00511 m_structureIsUptodate = true; 00512 } 00513 00514 out = m_transposedStructure + matrix; 00515 } 00516 internal::c_to_fortran_numbering(out); 00517 } 00518 00519 using Base::m_iparm; 00520 using Base::m_dparm; 00521 00522 ColSpMatrix m_transposedStructure; 00523 bool m_structureIsUptodate; 00524 }; 00525 00540 template<typename _MatrixType, int _UpLo> 00541 class PastixLLT : public PastixBase< PastixLLT<_MatrixType, _UpLo> > 00542 { 00543 public: 00544 typedef _MatrixType MatrixType; 00545 typedef PastixBase<PastixLLT<MatrixType, _UpLo> > Base; 00546 typedef typename Base::ColSpMatrix ColSpMatrix; 00547 00548 public: 00549 enum { UpLo = _UpLo }; 00550 PastixLLT() : Base() 00551 { 00552 init(); 00553 } 00554 00555 PastixLLT(const MatrixType& matrix):Base() 00556 { 00557 init(); 00558 compute(matrix); 00559 } 00560 00564 void compute (const MatrixType& matrix) 00565 { 00566 ColSpMatrix temp; 00567 grabMatrix(matrix, temp); 00568 Base::compute(temp); 00569 } 00570 00575 void analyzePattern(const MatrixType& matrix) 00576 { 00577 ColSpMatrix temp; 00578 grabMatrix(matrix, temp); 00579 Base::analyzePattern(temp); 00580 } 00584 void factorize(const MatrixType& matrix) 00585 { 00586 ColSpMatrix temp; 00587 grabMatrix(matrix, temp); 00588 Base::factorize(temp); 00589 } 00590 protected: 00591 using Base::m_iparm; 00592 00593 void init() 00594 { 00595 m_iparm(IPARM_SYM) = API_SYM_YES; 00596 m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT; 00597 } 00598 00599 void grabMatrix(const MatrixType& matrix, ColSpMatrix& out) 00600 { 00601 // Pastix supports only lower, column-major matrices 00602 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>(); 00603 internal::c_to_fortran_numbering(out); 00604 } 00605 }; 00606 00621 template<typename _MatrixType, int _UpLo> 00622 class PastixLDLT : public PastixBase< PastixLDLT<_MatrixType, _UpLo> > 00623 { 00624 public: 00625 typedef _MatrixType MatrixType; 00626 typedef PastixBase<PastixLDLT<MatrixType, _UpLo> > Base; 00627 typedef typename Base::ColSpMatrix ColSpMatrix; 00628 00629 public: 00630 enum { UpLo = _UpLo }; 00631 PastixLDLT():Base() 00632 { 00633 init(); 00634 } 00635 00636 PastixLDLT(const MatrixType& matrix):Base() 00637 { 00638 init(); 00639 compute(matrix); 00640 } 00641 00645 void compute (const MatrixType& matrix) 00646 { 00647 ColSpMatrix temp; 00648 grabMatrix(matrix, temp); 00649 Base::compute(temp); 00650 } 00651 00656 void analyzePattern(const MatrixType& matrix) 00657 { 00658 ColSpMatrix temp; 00659 grabMatrix(matrix, temp); 00660 Base::analyzePattern(temp); 00661 } 00665 void factorize(const MatrixType& matrix) 00666 { 00667 ColSpMatrix temp; 00668 grabMatrix(matrix, temp); 00669 Base::factorize(temp); 00670 } 00671 00672 protected: 00673 using Base::m_iparm; 00674 00675 void init() 00676 { 00677 m_iparm(IPARM_SYM) = API_SYM_YES; 00678 m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT; 00679 } 00680 00681 void grabMatrix(const MatrixType& matrix, ColSpMatrix& out) 00682 { 00683 // Pastix supports only lower, column-major matrices 00684 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>(); 00685 internal::c_to_fortran_numbering(out); 00686 } 00687 }; 00688 00689 namespace internal { 00690 00691 template<typename _MatrixType, typename Rhs> 00692 struct solve_retval<PastixBase<_MatrixType>, Rhs> 00693 : solve_retval_base<PastixBase<_MatrixType>, Rhs> 00694 { 00695 typedef PastixBase<_MatrixType> Dec; 00696 EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs) 00697 00698 template<typename Dest> void evalTo(Dest& dst) const 00699 { 00700 dec()._solve(rhs(),dst); 00701 } 00702 }; 00703 00704 template<typename _MatrixType, typename Rhs> 00705 struct sparse_solve_retval<PastixBase<_MatrixType>, Rhs> 00706 : sparse_solve_retval_base<PastixBase<_MatrixType>, Rhs> 00707 { 00708 typedef PastixBase<_MatrixType> Dec; 00709 EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs) 00710 00711 template<typename Dest> void evalTo(Dest& dst) const 00712 { 00713 this->defaultEvalTo(dst); 00714 } 00715 }; 00716 00717 } // end namespace internal 00718 00719 } // end namespace Eigen 00720 00721 #endif