$treeview $search $mathjax
Eigen
3.2.5
$projectbrief
|
$projectbrief
|
$searchbox |
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_SPARSEDENSEPRODUCT_H 00011 #define EIGEN_SPARSEDENSEPRODUCT_H 00012 00013 namespace Eigen { 00014 00015 template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType 00016 { 00017 typedef SparseTimeDenseProduct<Lhs,Rhs> Type; 00018 }; 00019 00020 template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1> 00021 { 00022 typedef typename internal::conditional< 00023 Lhs::IsRowMajor, 00024 SparseDenseOuterProduct<Rhs,Lhs,true>, 00025 SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type; 00026 }; 00027 00028 template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType 00029 { 00030 typedef DenseTimeSparseProduct<Lhs,Rhs> Type; 00031 }; 00032 00033 template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1> 00034 { 00035 typedef typename internal::conditional< 00036 Rhs::IsRowMajor, 00037 SparseDenseOuterProduct<Rhs,Lhs,true>, 00038 SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type; 00039 }; 00040 00041 namespace internal { 00042 00043 template<typename Lhs, typename Rhs, bool Tr> 00044 struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> > 00045 { 00046 typedef Sparse StorageKind; 00047 typedef typename scalar_product_traits<typename traits<Lhs>::Scalar, 00048 typename traits<Rhs>::Scalar>::ReturnType Scalar; 00049 typedef typename Lhs::Index Index; 00050 typedef typename Lhs::Nested LhsNested; 00051 typedef typename Rhs::Nested RhsNested; 00052 typedef typename remove_all<LhsNested>::type _LhsNested; 00053 typedef typename remove_all<RhsNested>::type _RhsNested; 00054 00055 enum { 00056 LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost, 00057 RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost, 00058 00059 RowsAtCompileTime = Tr ? int(traits<Rhs>::RowsAtCompileTime) : int(traits<Lhs>::RowsAtCompileTime), 00060 ColsAtCompileTime = Tr ? int(traits<Lhs>::ColsAtCompileTime) : int(traits<Rhs>::ColsAtCompileTime), 00061 MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime) : int(traits<Lhs>::MaxRowsAtCompileTime), 00062 MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime) : int(traits<Rhs>::MaxColsAtCompileTime), 00063 00064 Flags = Tr ? RowMajorBit : 0, 00065 00066 CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost 00067 }; 00068 }; 00069 00070 } // end namespace internal 00071 00072 template<typename Lhs, typename Rhs, bool Tr> 00073 class SparseDenseOuterProduct 00074 : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> > 00075 { 00076 public: 00077 00078 typedef SparseMatrixBase<SparseDenseOuterProduct> Base; 00079 EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct) 00080 typedef internal::traits<SparseDenseOuterProduct> Traits; 00081 00082 private: 00083 00084 typedef typename Traits::LhsNested LhsNested; 00085 typedef typename Traits::RhsNested RhsNested; 00086 typedef typename Traits::_LhsNested _LhsNested; 00087 typedef typename Traits::_RhsNested _RhsNested; 00088 00089 public: 00090 00091 class InnerIterator; 00092 00093 EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs) 00094 : m_lhs(lhs), m_rhs(rhs) 00095 { 00096 EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); 00097 } 00098 00099 EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs) 00100 : m_lhs(lhs), m_rhs(rhs) 00101 { 00102 EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); 00103 } 00104 00105 EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); } 00106 EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); } 00107 00108 EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } 00109 EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } 00110 00111 protected: 00112 LhsNested m_lhs; 00113 RhsNested m_rhs; 00114 }; 00115 00116 template<typename Lhs, typename Rhs, bool Transpose> 00117 class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator 00118 { 00119 typedef typename _LhsNested::InnerIterator Base; 00120 typedef typename SparseDenseOuterProduct::Index Index; 00121 public: 00122 EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer) 00123 : Base(prod.lhs(), 0), m_outer(outer), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() )) 00124 { } 00125 00126 inline Index outer() const { return m_outer; } 00127 inline Index row() const { return Transpose ? m_outer : Base::index(); } 00128 inline Index col() const { return Transpose ? Base::index() : m_outer; } 00129 00130 inline Scalar value() const { return Base::value() * m_factor; } 00131 00132 protected: 00133 static Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense()) 00134 { 00135 return rhs.coeff(outer); 00136 } 00137 00138 static Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse()) 00139 { 00140 typename Traits::_RhsNested::InnerIterator it(rhs, outer); 00141 if (it && it.index()==0) 00142 return it.value(); 00143 00144 return Scalar(0); 00145 } 00146 00147 Index m_outer; 00148 Scalar m_factor; 00149 }; 00150 00151 namespace internal { 00152 template<typename Lhs, typename Rhs> 00153 struct traits<SparseTimeDenseProduct<Lhs,Rhs> > 00154 : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> > 00155 { 00156 typedef Dense StorageKind; 00157 typedef MatrixXpr XprKind; 00158 }; 00159 00160 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, 00161 int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor, 00162 bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1> 00163 struct sparse_time_dense_product_impl; 00164 00165 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 00166 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true> 00167 { 00168 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 00169 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 00170 typedef typename internal::remove_all<DenseResType>::type Res; 00171 typedef typename Lhs::Index Index; 00172 typedef typename Lhs::InnerIterator LhsInnerIterator; 00173 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 00174 { 00175 for(Index c=0; c<rhs.cols(); ++c) 00176 { 00177 Index n = lhs.outerSize(); 00178 for(Index j=0; j<n; ++j) 00179 { 00180 typename Res::Scalar tmp(0); 00181 for(LhsInnerIterator it(lhs,j); it ;++it) 00182 tmp += it.value() * rhs.coeff(it.index(),c); 00183 res.coeffRef(j,c) += alpha * tmp; 00184 } 00185 } 00186 } 00187 }; 00188 00189 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 00190 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true> 00191 { 00192 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 00193 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 00194 typedef typename internal::remove_all<DenseResType>::type Res; 00195 typedef typename Lhs::InnerIterator LhsInnerIterator; 00196 typedef typename Lhs::Index Index; 00197 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 00198 { 00199 for(Index c=0; c<rhs.cols(); ++c) 00200 { 00201 for(Index j=0; j<lhs.outerSize(); ++j) 00202 { 00203 typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c); 00204 for(LhsInnerIterator it(lhs,j); it ;++it) 00205 res.coeffRef(it.index(),c) += it.value() * rhs_j; 00206 } 00207 } 00208 } 00209 }; 00210 00211 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 00212 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false> 00213 { 00214 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 00215 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 00216 typedef typename internal::remove_all<DenseResType>::type Res; 00217 typedef typename Lhs::InnerIterator LhsInnerIterator; 00218 typedef typename Lhs::Index Index; 00219 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 00220 { 00221 for(Index j=0; j<lhs.outerSize(); ++j) 00222 { 00223 typename Res::RowXpr res_j(res.row(j)); 00224 for(LhsInnerIterator it(lhs,j); it ;++it) 00225 res_j += (alpha*it.value()) * rhs.row(it.index()); 00226 } 00227 } 00228 }; 00229 00230 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 00231 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false> 00232 { 00233 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 00234 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 00235 typedef typename internal::remove_all<DenseResType>::type Res; 00236 typedef typename Lhs::InnerIterator LhsInnerIterator; 00237 typedef typename Lhs::Index Index; 00238 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 00239 { 00240 for(Index j=0; j<lhs.outerSize(); ++j) 00241 { 00242 typename Rhs::ConstRowXpr rhs_j(rhs.row(j)); 00243 for(LhsInnerIterator it(lhs,j); it ;++it) 00244 res.row(it.index()) += (alpha*it.value()) * rhs_j; 00245 } 00246 } 00247 }; 00248 00249 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType> 00250 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) 00251 { 00252 sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha); 00253 } 00254 00255 } // end namespace internal 00256 00257 template<typename Lhs, typename Rhs> 00258 class SparseTimeDenseProduct 00259 : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> 00260 { 00261 public: 00262 EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct) 00263 00264 SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) 00265 {} 00266 00267 template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const 00268 { 00269 internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha); 00270 } 00271 00272 private: 00273 SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&); 00274 }; 00275 00276 00277 // dense = dense * sparse 00278 namespace internal { 00279 template<typename Lhs, typename Rhs> 00280 struct traits<DenseTimeSparseProduct<Lhs,Rhs> > 00281 : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> > 00282 { 00283 typedef Dense StorageKind; 00284 }; 00285 } // end namespace internal 00286 00287 template<typename Lhs, typename Rhs> 00288 class DenseTimeSparseProduct 00289 : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> 00290 { 00291 public: 00292 EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct) 00293 00294 DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) 00295 {} 00296 00297 template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const 00298 { 00299 Transpose<const _LhsNested> lhs_t(m_lhs); 00300 Transpose<const _RhsNested> rhs_t(m_rhs); 00301 Transpose<Dest> dest_t(dest); 00302 internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha); 00303 } 00304 00305 private: 00306 DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&); 00307 }; 00308 00309 } // end namespace Eigen 00310 00311 #endif // EIGEN_SPARSEDENSEPRODUCT_H