$treeview $search $mathjax
Eigen-unsupported
3.2.5
$projectbrief
|
$projectbrief
|
$searchbox |
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_MATRIX_SQUARE_ROOT 00011 #define EIGEN_MATRIX_SQUARE_ROOT 00012 00013 namespace Eigen { 00014 00026 template <typename MatrixType> 00027 class MatrixSquareRootQuasiTriangular 00028 { 00029 public: 00030 00039 MatrixSquareRootQuasiTriangular(const MatrixType& A) 00040 : m_A(A) 00041 { 00042 eigen_assert(A.rows() == A.cols()); 00043 } 00044 00053 template <typename ResultType> void compute(ResultType &result); 00054 00055 private: 00056 typedef typename MatrixType::Index Index; 00057 typedef typename MatrixType::Scalar Scalar; 00058 00059 void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); 00060 void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); 00061 void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i); 00062 void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 00063 typename MatrixType::Index i, typename MatrixType::Index j); 00064 void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 00065 typename MatrixType::Index i, typename MatrixType::Index j); 00066 void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 00067 typename MatrixType::Index i, typename MatrixType::Index j); 00068 void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 00069 typename MatrixType::Index i, typename MatrixType::Index j); 00070 00071 template <typename SmallMatrixType> 00072 static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, 00073 const SmallMatrixType& B, const SmallMatrixType& C); 00074 00075 const MatrixType& m_A; 00076 }; 00077 00078 template <typename MatrixType> 00079 template <typename ResultType> 00080 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result) 00081 { 00082 result.resize(m_A.rows(), m_A.cols()); 00083 computeDiagonalPartOfSqrt(result, m_A); 00084 computeOffDiagonalPartOfSqrt(result, m_A); 00085 } 00086 00087 // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size 00088 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T 00089 template <typename MatrixType> 00090 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, 00091 const MatrixType& T) 00092 { 00093 using std::sqrt; 00094 const Index size = m_A.rows(); 00095 for (Index i = 0; i < size; i++) { 00096 if (i == size - 1 || T.coeff(i+1, i) == 0) { 00097 eigen_assert(T(i,i) >= 0); 00098 sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i)); 00099 } 00100 else { 00101 compute2x2diagonalBlock(sqrtT, T, i); 00102 ++i; 00103 } 00104 } 00105 } 00106 00107 // pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T. 00108 // post: sqrtT is the square root of T. 00109 template <typename MatrixType> 00110 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, 00111 const MatrixType& T) 00112 { 00113 const Index size = m_A.rows(); 00114 for (Index j = 1; j < size; j++) { 00115 if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block 00116 continue; 00117 for (Index i = j-1; i >= 0; i--) { 00118 if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block 00119 continue; 00120 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0); 00121 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0); 00122 if (iBlockIs2x2 && jBlockIs2x2) 00123 compute2x2offDiagonalBlock(sqrtT, T, i, j); 00124 else if (iBlockIs2x2 && !jBlockIs2x2) 00125 compute2x1offDiagonalBlock(sqrtT, T, i, j); 00126 else if (!iBlockIs2x2 && jBlockIs2x2) 00127 compute1x2offDiagonalBlock(sqrtT, T, i, j); 00128 else if (!iBlockIs2x2 && !jBlockIs2x2) 00129 compute1x1offDiagonalBlock(sqrtT, T, i, j); 00130 } 00131 } 00132 } 00133 00134 // pre: T.block(i,i,2,2) has complex conjugate eigenvalues 00135 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2) 00136 template <typename MatrixType> 00137 void MatrixSquareRootQuasiTriangular<MatrixType> 00138 ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i) 00139 { 00140 // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere 00141 // in EigenSolver. If we expose it, we could call it directly from here. 00142 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i); 00143 EigenSolver<Matrix<Scalar,2,2> > es(block); 00144 sqrtT.template block<2,2>(i,i) 00145 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real(); 00146 } 00147 00148 // pre: block structure of T is such that (i,j) is a 1x1 block, 00149 // all blocks of sqrtT to left of and below (i,j) are correct 00150 // post: sqrtT(i,j) has the correct value 00151 template <typename MatrixType> 00152 void MatrixSquareRootQuasiTriangular<MatrixType> 00153 ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 00154 typename MatrixType::Index i, typename MatrixType::Index j) 00155 { 00156 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value(); 00157 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j)); 00158 } 00159 00160 // similar to compute1x1offDiagonalBlock() 00161 template <typename MatrixType> 00162 void MatrixSquareRootQuasiTriangular<MatrixType> 00163 ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 00164 typename MatrixType::Index i, typename MatrixType::Index j) 00165 { 00166 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j); 00167 if (j-i > 1) 00168 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2); 00169 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity(); 00170 A += sqrtT.template block<2,2>(j,j).transpose(); 00171 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose()); 00172 } 00173 00174 // similar to compute1x1offDiagonalBlock() 00175 template <typename MatrixType> 00176 void MatrixSquareRootQuasiTriangular<MatrixType> 00177 ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 00178 typename MatrixType::Index i, typename MatrixType::Index j) 00179 { 00180 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j); 00181 if (j-i > 2) 00182 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1); 00183 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity(); 00184 A += sqrtT.template block<2,2>(i,i); 00185 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs); 00186 } 00187 00188 // similar to compute1x1offDiagonalBlock() 00189 template <typename MatrixType> 00190 void MatrixSquareRootQuasiTriangular<MatrixType> 00191 ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 00192 typename MatrixType::Index i, typename MatrixType::Index j) 00193 { 00194 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i); 00195 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j); 00196 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j); 00197 if (j-i > 2) 00198 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2); 00199 Matrix<Scalar,2,2> X; 00200 solveAuxiliaryEquation(X, A, B, C); 00201 sqrtT.template block<2,2>(i,j) = X; 00202 } 00203 00204 // solves the equation A X + X B = C where all matrices are 2-by-2 00205 template <typename MatrixType> 00206 template <typename SmallMatrixType> 00207 void MatrixSquareRootQuasiTriangular<MatrixType> 00208 ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, 00209 const SmallMatrixType& B, const SmallMatrixType& C) 00210 { 00211 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value), 00212 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); 00213 00214 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero(); 00215 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0); 00216 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1); 00217 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0); 00218 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1); 00219 coeffMatrix.coeffRef(0,1) = B.coeff(1,0); 00220 coeffMatrix.coeffRef(0,2) = A.coeff(0,1); 00221 coeffMatrix.coeffRef(1,0) = B.coeff(0,1); 00222 coeffMatrix.coeffRef(1,3) = A.coeff(0,1); 00223 coeffMatrix.coeffRef(2,0) = A.coeff(1,0); 00224 coeffMatrix.coeffRef(2,3) = B.coeff(1,0); 00225 coeffMatrix.coeffRef(3,1) = A.coeff(1,0); 00226 coeffMatrix.coeffRef(3,2) = B.coeff(0,1); 00227 00228 Matrix<Scalar,4,1> rhs; 00229 rhs.coeffRef(0) = C.coeff(0,0); 00230 rhs.coeffRef(1) = C.coeff(0,1); 00231 rhs.coeffRef(2) = C.coeff(1,0); 00232 rhs.coeffRef(3) = C.coeff(1,1); 00233 00234 Matrix<Scalar,4,1> result; 00235 result = coeffMatrix.fullPivLu().solve(rhs); 00236 00237 X.coeffRef(0,0) = result.coeff(0); 00238 X.coeffRef(0,1) = result.coeff(1); 00239 X.coeffRef(1,0) = result.coeff(2); 00240 X.coeffRef(1,1) = result.coeff(3); 00241 } 00242 00243 00255 template <typename MatrixType> 00256 class MatrixSquareRootTriangular 00257 { 00258 public: 00259 MatrixSquareRootTriangular(const MatrixType& A) 00260 : m_A(A) 00261 { 00262 eigen_assert(A.rows() == A.cols()); 00263 } 00264 00274 template <typename ResultType> void compute(ResultType &result); 00275 00276 private: 00277 const MatrixType& m_A; 00278 }; 00279 00280 template <typename MatrixType> 00281 template <typename ResultType> 00282 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result) 00283 { 00284 using std::sqrt; 00285 00286 // Compute square root of m_A and store it in upper triangular part of result 00287 // This uses that the square root of triangular matrices can be computed directly. 00288 result.resize(m_A.rows(), m_A.cols()); 00289 typedef typename MatrixType::Index Index; 00290 for (Index i = 0; i < m_A.rows(); i++) { 00291 result.coeffRef(i,i) = sqrt(m_A.coeff(i,i)); 00292 } 00293 for (Index j = 1; j < m_A.cols(); j++) { 00294 for (Index i = j-1; i >= 0; i--) { 00295 typedef typename MatrixType::Scalar Scalar; 00296 // if i = j-1, then segment has length 0 so tmp = 0 00297 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value(); 00298 // denominator may be zero if original matrix is singular 00299 result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); 00300 } 00301 } 00302 } 00303 00304 00312 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex> 00313 class MatrixSquareRoot 00314 { 00315 public: 00316 00324 MatrixSquareRoot(const MatrixType& A); 00325 00333 template <typename ResultType> void compute(ResultType &result); 00334 }; 00335 00336 00337 // ********** Partial specialization for real matrices ********** 00338 00339 template <typename MatrixType> 00340 class MatrixSquareRoot<MatrixType, 0> 00341 { 00342 public: 00343 00344 MatrixSquareRoot(const MatrixType& A) 00345 : m_A(A) 00346 { 00347 eigen_assert(A.rows() == A.cols()); 00348 } 00349 00350 template <typename ResultType> void compute(ResultType &result) 00351 { 00352 // Compute Schur decomposition of m_A 00353 const RealSchur<MatrixType> schurOfA(m_A); 00354 const MatrixType& T = schurOfA.matrixT(); 00355 const MatrixType& U = schurOfA.matrixU(); 00356 00357 // Compute square root of T 00358 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols()); 00359 MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT); 00360 00361 // Compute square root of m_A 00362 result = U * sqrtT * U.adjoint(); 00363 } 00364 00365 private: 00366 const MatrixType& m_A; 00367 }; 00368 00369 00370 // ********** Partial specialization for complex matrices ********** 00371 00372 template <typename MatrixType> 00373 class MatrixSquareRoot<MatrixType, 1> 00374 { 00375 public: 00376 00377 MatrixSquareRoot(const MatrixType& A) 00378 : m_A(A) 00379 { 00380 eigen_assert(A.rows() == A.cols()); 00381 } 00382 00383 template <typename ResultType> void compute(ResultType &result) 00384 { 00385 // Compute Schur decomposition of m_A 00386 const ComplexSchur<MatrixType> schurOfA(m_A); 00387 const MatrixType& T = schurOfA.matrixT(); 00388 const MatrixType& U = schurOfA.matrixU(); 00389 00390 // Compute square root of T 00391 MatrixType sqrtT; 00392 MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT); 00393 00394 // Compute square root of m_A 00395 result = U * (sqrtT.template triangularView<Upper>() * U.adjoint()); 00396 } 00397 00398 private: 00399 const MatrixType& m_A; 00400 }; 00401 00402 00415 template<typename Derived> class MatrixSquareRootReturnValue 00416 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> > 00417 { 00418 typedef typename Derived::Index Index; 00419 public: 00425 MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { } 00426 00432 template <typename ResultType> 00433 inline void evalTo(ResultType& result) const 00434 { 00435 const typename Derived::PlainObject srcEvaluated = m_src.eval(); 00436 MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated); 00437 me.compute(result); 00438 } 00439 00440 Index rows() const { return m_src.rows(); } 00441 Index cols() const { return m_src.cols(); } 00442 00443 protected: 00444 const Derived& m_src; 00445 private: 00446 MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&); 00447 }; 00448 00449 namespace internal { 00450 template<typename Derived> 00451 struct traits<MatrixSquareRootReturnValue<Derived> > 00452 { 00453 typedef typename Derived::PlainObject ReturnType; 00454 }; 00455 } 00456 00457 template <typename Derived> 00458 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const 00459 { 00460 eigen_assert(rows() == cols()); 00461 return MatrixSquareRootReturnValue<Derived>(derived()); 00462 } 00463 00464 } // end namespace Eigen 00465 00466 #endif // EIGEN_MATRIX_FUNCTION