$treeview $search $mathjax
Eigen
3.2.5
$projectbrief
|
$projectbrief
|
$searchbox |
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 00011 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00017 00018 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major 00019 template<typename Lhs, typename Rhs, typename ResultType> 00020 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance) 00021 { 00022 // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res); 00023 00024 typedef typename remove_all<Lhs>::type::Scalar Scalar; 00025 typedef typename remove_all<Lhs>::type::Index Index; 00026 00027 // make sure to call innerSize/outerSize since we fake the storage order. 00028 Index rows = lhs.innerSize(); 00029 Index cols = rhs.outerSize(); 00030 //Index size = lhs.outerSize(); 00031 eigen_assert(lhs.outerSize() == rhs.innerSize()); 00032 00033 // allocate a temporary buffer 00034 AmbiVector<Scalar,Index> tempVector(rows); 00035 00036 // estimate the number of non zero entries 00037 // given a rhs column containing Y non zeros, we assume that the respective Y columns 00038 // of the lhs differs in average of one non zeros, thus the number of non zeros for 00039 // the product of a rhs column with the lhs is X+Y where X is the average number of non zero 00040 // per column of the lhs. 00041 // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) 00042 Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros(); 00043 00044 // mimics a resizeByInnerOuter: 00045 if(ResultType::IsRowMajor) 00046 res.resize(cols, rows); 00047 else 00048 res.resize(rows, cols); 00049 00050 res.reserve(estimated_nnz_prod); 00051 double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols()); 00052 for (Index j=0; j<cols; ++j) 00053 { 00054 // FIXME: 00055 //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows()); 00056 // let's do a more accurate determination of the nnz ratio for the current column j of res 00057 tempVector.init(ratioColRes); 00058 tempVector.setZero(); 00059 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) 00060 { 00061 // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) 00062 tempVector.restart(); 00063 Scalar x = rhsIt.value(); 00064 for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt) 00065 { 00066 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; 00067 } 00068 } 00069 res.startVec(j); 00070 for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it) 00071 res.insertBackByOuterInner(j,it.index()) = it.value(); 00072 } 00073 res.finalize(); 00074 } 00075 00076 template<typename Lhs, typename Rhs, typename ResultType, 00077 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, 00078 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, 00079 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> 00080 struct sparse_sparse_product_with_pruning_selector; 00081 00082 template<typename Lhs, typename Rhs, typename ResultType> 00083 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> 00084 { 00085 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar; 00086 typedef typename ResultType::RealScalar RealScalar; 00087 00088 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00089 { 00090 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 00091 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); 00092 res.swap(_res); 00093 } 00094 }; 00095 00096 template<typename Lhs, typename Rhs, typename ResultType> 00097 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> 00098 { 00099 typedef typename ResultType::RealScalar RealScalar; 00100 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00101 { 00102 // we need a col-major matrix to hold the result 00103 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> SparseTemporaryType; 00104 SparseTemporaryType _res(res.rows(), res.cols()); 00105 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); 00106 res = _res; 00107 } 00108 }; 00109 00110 template<typename Lhs, typename Rhs, typename ResultType> 00111 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> 00112 { 00113 typedef typename ResultType::RealScalar RealScalar; 00114 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00115 { 00116 // let's transpose the product to get a column x column product 00117 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 00118 internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); 00119 res.swap(_res); 00120 } 00121 }; 00122 00123 template<typename Lhs, typename Rhs, typename ResultType> 00124 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> 00125 { 00126 typedef typename ResultType::RealScalar RealScalar; 00127 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00128 { 00129 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs; 00130 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs; 00131 ColMajorMatrixLhs colLhs(lhs); 00132 ColMajorMatrixRhs colRhs(rhs); 00133 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance); 00134 00135 // let's transpose the product to get a column x column product 00136 // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; 00137 // SparseTemporaryType _res(res.cols(), res.rows()); 00138 // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); 00139 // res = _res.transpose(); 00140 } 00141 }; 00142 00143 // NOTE the 2 others cases (col row *) must never occur since they are caught 00144 // by ProductReturnType which transforms it to (col col *) by evaluating rhs. 00145 00146 } // end namespace internal 00147 00148 } // end namespace Eigen 00149 00150 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H