Another set of UI messages fixes and tweaks! No functional changes.
[blender.git] / extern / Eigen3 / Eigen / src / Core / products / TriangularMatrixVector.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // Eigen is free software; you can redistribute it and/or
7 // modify it under the terms of the GNU Lesser General Public
8 // License as published by the Free Software Foundation; either
9 // version 3 of the License, or (at your option) any later version.
10 //
11 // Alternatively, you can redistribute it and/or
12 // modify it under the terms of the GNU General Public License as
13 // published by the Free Software Foundation; either version 2 of
14 // the License, or (at your option) any later version.
15 //
16 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
17 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
18 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
19 // GNU General Public License for more details.
20 //
21 // You should have received a copy of the GNU Lesser General Public
22 // License and a copy of the GNU General Public License along with
23 // Eigen. If not, see <http://www.gnu.org/licenses/>.
24
25 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
26 #define EIGEN_TRIANGULARMATRIXVECTOR_H
27
28 namespace internal {
29
30 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder>
31 struct product_triangular_matrix_vector;
32
33 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
34 struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor>
35 {
36   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
37   enum {
38     IsLower = ((Mode&Lower)==Lower),
39     HasUnitDiag = (Mode & UnitDiag)==UnitDiag
40   };
41   static EIGEN_DONT_INLINE  void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
42                                      const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
43   {
44     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
45
46     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
47     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
48     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
49     
50     typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
51     const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
52     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
53
54     typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
55     ResMap res(_res,rows);
56
57     for (Index pi=0; pi<cols; pi+=PanelWidth)
58     {
59       Index actualPanelWidth = (std::min)(PanelWidth, cols-pi);
60       for (Index k=0; k<actualPanelWidth; ++k)
61       {
62         Index i = pi + k;
63         Index s = IsLower ? (HasUnitDiag ? i+1 : i ) : pi;
64         Index r = IsLower ? actualPanelWidth-k : k+1;
65         if ((!HasUnitDiag) || (--r)>0)
66           res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
67         if (HasUnitDiag)
68           res.coeffRef(i) += alpha * cjRhs.coeff(i);
69       }
70       Index r = IsLower ? cols - pi - actualPanelWidth : pi;
71       if (r>0)
72       {
73         Index s = IsLower ? pi+actualPanelWidth : 0;
74         general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
75             r, actualPanelWidth,
76             &lhs.coeffRef(s,pi), lhsStride,
77             &rhs.coeffRef(pi), rhsIncr,
78             &res.coeffRef(s), resIncr, alpha);
79       }
80     }
81   }
82 };
83
84 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
85 struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor>
86 {
87   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
88   enum {
89     IsLower = ((Mode&Lower)==Lower),
90     HasUnitDiag = (Mode & UnitDiag)==UnitDiag
91   };
92   static void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
93                   const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
94   {
95     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
96
97     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
98     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
99     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
100
101     typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
102     const RhsMap rhs(_rhs,cols);
103     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
104
105     typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
106     ResMap res(_res,rows,InnerStride<>(resIncr));
107     
108     for (Index pi=0; pi<cols; pi+=PanelWidth)
109     {
110       Index actualPanelWidth = (std::min)(PanelWidth, cols-pi);
111       for (Index k=0; k<actualPanelWidth; ++k)
112       {
113         Index i = pi + k;
114         Index s = IsLower ? pi  : (HasUnitDiag ? i+1 : i);
115         Index r = IsLower ? k+1 : actualPanelWidth-k;
116         if ((!HasUnitDiag) || (--r)>0)
117           res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
118         if (HasUnitDiag)
119           res.coeffRef(i) += alpha * cjRhs.coeff(i);
120       }
121       Index r = IsLower ? pi : cols - pi - actualPanelWidth;
122       if (r>0)
123       {
124         Index s = IsLower ? 0 : pi + actualPanelWidth;
125         general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
126             actualPanelWidth, r,
127             &lhs.coeffRef(pi,s), lhsStride,
128             &rhs.coeffRef(s), rhsIncr,
129             &res.coeffRef(pi), resIncr, alpha);
130       }
131     }
132   }
133 };
134
135 /***************************************************************************
136 * Wrapper to product_triangular_vector
137 ***************************************************************************/
138
139 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
140 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
141  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
142 {};
143
144 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
145 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
146  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
147 {};
148
149
150 template<int StorageOrder>
151 struct trmv_selector;
152
153 } // end namespace internal
154
155 template<int Mode, typename Lhs, typename Rhs>
156 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
157   : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
158 {
159   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
160
161   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
162
163   template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
164   {
165     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
166   
167     internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
168   }
169 };
170
171 template<int Mode, typename Lhs, typename Rhs>
172 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
173   : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
174 {
175   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
176
177   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
178
179   template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
180   {
181     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
182
183     typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
184     Transpose<Dest> dstT(dst);
185     internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
186       TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
187   }
188 };
189
190 namespace internal {
191
192 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
193   
194 template<> struct trmv_selector<ColMajor>
195 {
196   template<int Mode, typename Lhs, typename Rhs, typename Dest>
197   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
198   {
199     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
200     typedef typename ProductType::Index Index;
201     typedef typename ProductType::LhsScalar   LhsScalar;
202     typedef typename ProductType::RhsScalar   RhsScalar;
203     typedef typename ProductType::Scalar      ResScalar;
204     typedef typename ProductType::RealScalar  RealScalar;
205     typedef typename ProductType::ActualLhsType ActualLhsType;
206     typedef typename ProductType::ActualRhsType ActualRhsType;
207     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
208     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
209     typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
210
211     const ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs());
212     const ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs());
213
214     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
215                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
216
217     enum {
218       // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
219       // on, the other hand it is good for the cache to pack the vector anyways...
220       EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
221       ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
222       MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
223     };
224
225     gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
226
227     bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
228     bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
229     
230     RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
231
232     ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
233                                                   evalToDest ? dest.data() : static_dest.data());
234
235     if(!evalToDest)
236     {
237       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
238       int size = dest.size();
239       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
240       #endif
241       if(!alphaIsCompatible)
242       {
243         MappedDest(actualDestPtr, dest.size()).setZero();
244         compatibleAlpha = RhsScalar(1);
245       }
246       else
247         MappedDest(actualDestPtr, dest.size()) = dest;
248     }
249     
250     internal::product_triangular_matrix_vector
251       <Index,Mode,
252        LhsScalar, LhsBlasTraits::NeedToConjugate,
253        RhsScalar, RhsBlasTraits::NeedToConjugate,
254        ColMajor>
255       ::run(actualLhs.rows(),actualLhs.cols(),
256             actualLhs.data(),actualLhs.outerStride(),
257             actualRhs.data(),actualRhs.innerStride(),
258             actualDestPtr,1,compatibleAlpha);
259
260     if (!evalToDest)
261     {
262       if(!alphaIsCompatible)
263         dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
264       else
265         dest = MappedDest(actualDestPtr, dest.size());
266     }
267   }
268 };
269
270 template<> struct trmv_selector<RowMajor>
271 {
272   template<int Mode, typename Lhs, typename Rhs, typename Dest>
273   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
274   {
275     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
276     typedef typename ProductType::LhsScalar LhsScalar;
277     typedef typename ProductType::RhsScalar RhsScalar;
278     typedef typename ProductType::Scalar    ResScalar;
279     typedef typename ProductType::Index Index;
280     typedef typename ProductType::ActualLhsType ActualLhsType;
281     typedef typename ProductType::ActualRhsType ActualRhsType;
282     typedef typename ProductType::_ActualRhsType _ActualRhsType;
283     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
284     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
285
286     typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
287     typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
288
289     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
290                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
291
292     enum {
293       DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
294     };
295
296     gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
297
298     ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
299         DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
300
301     if(!DirectlyUseRhs)
302     {
303       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
304       int size = actualRhs.size();
305       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
306       #endif
307       Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
308     }
309     
310     internal::product_triangular_matrix_vector
311       <Index,Mode,
312        LhsScalar, LhsBlasTraits::NeedToConjugate,
313        RhsScalar, RhsBlasTraits::NeedToConjugate,
314        RowMajor>
315       ::run(actualLhs.rows(),actualLhs.cols(),
316             actualLhs.data(),actualLhs.outerStride(),
317             actualRhsPtr,1,
318             dest.data(),dest.innerStride(),
319             actualAlpha);
320   }
321 };
322
323 } // end namespace internal
324
325 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H