$treeview $search $mathjax
Eigen-unsupported
3.2.5
$projectbrief
|
$projectbrief
|
$searchbox |
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de> 00005 // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de> 00006 // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net> 00007 // 00008 // This Source Code Form is subject to the terms of the Mozilla 00009 // Public License v. 2.0. If a copy of the MPL was not distributed 00010 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00011 00012 #ifndef KRONECKER_TENSOR_PRODUCT_H 00013 #define KRONECKER_TENSOR_PRODUCT_H 00014 00015 namespace Eigen { 00016 00017 template<typename Scalar, int Options, typename Index> class SparseMatrix; 00018 00029 template<typename Lhs, typename Rhs> 00030 class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> > 00031 { 00032 private: 00033 typedef ReturnByValue<KroneckerProduct> Base; 00034 typedef typename Base::Scalar Scalar; 00035 typedef typename Base::Index Index; 00036 00037 public: 00039 KroneckerProduct(const Lhs& A, const Rhs& B) 00040 : m_A(A), m_B(B) 00041 {} 00042 00044 template<typename Dest> void evalTo(Dest& dst) const; 00045 00046 inline Index rows() const { return m_A.rows() * m_B.rows(); } 00047 inline Index cols() const { return m_A.cols() * m_B.cols(); } 00048 00049 Scalar coeff(Index row, Index col) const 00050 { 00051 return m_A.coeff(row / m_B.rows(), col / m_B.cols()) * 00052 m_B.coeff(row % m_B.rows(), col % m_B.cols()); 00053 } 00054 00055 Scalar coeff(Index i) const 00056 { 00057 EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct); 00058 return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size()); 00059 } 00060 00061 private: 00062 typename Lhs::Nested m_A; 00063 typename Rhs::Nested m_B; 00064 }; 00065 00079 template<typename Lhs, typename Rhs> 00080 class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> > 00081 { 00082 private: 00083 typedef typename internal::traits<KroneckerProductSparse>::Index Index; 00084 00085 public: 00087 KroneckerProductSparse(const Lhs& A, const Rhs& B) 00088 : m_A(A), m_B(B) 00089 {} 00090 00092 template<typename Dest> void evalTo(Dest& dst) const; 00093 00094 inline Index rows() const { return m_A.rows() * m_B.rows(); } 00095 inline Index cols() const { return m_A.cols() * m_B.cols(); } 00096 00097 template<typename Scalar, int Options, typename Index> 00098 operator SparseMatrix<Scalar, Options, Index>() 00099 { 00100 SparseMatrix<Scalar, Options, Index> result; 00101 evalTo(result.derived()); 00102 return result; 00103 } 00104 00105 private: 00106 typename Lhs::Nested m_A; 00107 typename Rhs::Nested m_B; 00108 }; 00109 00110 template<typename Lhs, typename Rhs> 00111 template<typename Dest> 00112 void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const 00113 { 00114 const int BlockRows = Rhs::RowsAtCompileTime, 00115 BlockCols = Rhs::ColsAtCompileTime; 00116 const Index Br = m_B.rows(), 00117 Bc = m_B.cols(); 00118 for (Index i=0; i < m_A.rows(); ++i) 00119 for (Index j=0; j < m_A.cols(); ++j) 00120 Block<Dest,BlockRows,BlockCols>(dst,i*Br,j*Bc,Br,Bc) = m_A.coeff(i,j) * m_B; 00121 } 00122 00123 template<typename Lhs, typename Rhs> 00124 template<typename Dest> 00125 void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const 00126 { 00127 const Index Br = m_B.rows(), 00128 Bc = m_B.cols(); 00129 dst.resize(rows(),cols()); 00130 dst.resizeNonZeros(0); 00131 dst.reserve(m_A.nonZeros() * m_B.nonZeros()); 00132 00133 for (Index kA=0; kA < m_A.outerSize(); ++kA) 00134 { 00135 for (Index kB=0; kB < m_B.outerSize(); ++kB) 00136 { 00137 for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA) 00138 { 00139 for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB) 00140 { 00141 const Index i = itA.row() * Br + itB.row(), 00142 j = itA.col() * Bc + itB.col(); 00143 dst.insert(i,j) = itA.value() * itB.value(); 00144 } 00145 } 00146 } 00147 } 00148 } 00149 00150 namespace internal { 00151 00152 template<typename _Lhs, typename _Rhs> 00153 struct traits<KroneckerProduct<_Lhs,_Rhs> > 00154 { 00155 typedef typename remove_all<_Lhs>::type Lhs; 00156 typedef typename remove_all<_Rhs>::type Rhs; 00157 typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; 00158 00159 enum { 00160 Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret, 00161 Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret, 00162 MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret, 00163 MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret, 00164 CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost 00165 }; 00166 00167 typedef Matrix<Scalar,Rows,Cols> ReturnType; 00168 }; 00169 00170 template<typename _Lhs, typename _Rhs> 00171 struct traits<KroneckerProductSparse<_Lhs,_Rhs> > 00172 { 00173 typedef MatrixXpr XprKind; 00174 typedef typename remove_all<_Lhs>::type Lhs; 00175 typedef typename remove_all<_Rhs>::type Rhs; 00176 typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; 00177 typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind>::ret StorageKind; 00178 typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index; 00179 00180 enum { 00181 LhsFlags = Lhs::Flags, 00182 RhsFlags = Rhs::Flags, 00183 00184 RowsAtCompileTime = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret, 00185 ColsAtCompileTime = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret, 00186 MaxRowsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret, 00187 MaxColsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret, 00188 00189 EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit), 00190 RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), 00191 00192 Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) 00193 | EvalBeforeNestingBit | EvalBeforeAssigningBit, 00194 CoeffReadCost = Dynamic 00195 }; 00196 }; 00197 00198 } // end namespace internal 00199 00219 template<typename A, typename B> 00220 KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<B>& b) 00221 { 00222 return KroneckerProduct<A, B>(a.derived(), b.derived()); 00223 } 00224 00236 template<typename A, typename B> 00237 KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenBase<B>& b) 00238 { 00239 return KroneckerProductSparse<A,B>(a.derived(), b.derived()); 00240 } 00241 00242 } // end namespace Eigen 00243 00244 #endif // KRONECKER_TENSOR_PRODUCT_H