Cycles: svn merge -r41225:41232 ^/trunk/blender
[blender.git] / extern / Eigen2 / Eigen / src / Sparse / UmfPackSupport.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra. Eigen itself is part of the KDE project.
3 //
4 // Copyright (C) 2008-2009 Gael Guennebaud <g.gael@free.fr>
5 //
6 // Eigen is free software; you can redistribute it and/or
7 // modify it under the terms of the GNU Lesser General Public
8 // License as published by the Free Software Foundation; either
9 // version 3 of the License, or (at your option) any later version.
10 //
11 // Alternatively, you can redistribute it and/or
12 // modify it under the terms of the GNU General Public License as
13 // published by the Free Software Foundation; either version 2 of
14 // the License, or (at your option) any later version.
15 //
16 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
17 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
18 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
19 // GNU General Public License for more details.
20 //
21 // You should have received a copy of the GNU Lesser General Public
22 // License and a copy of the GNU General Public License along with
23 // Eigen. If not, see <http://www.gnu.org/licenses/>.
24
25 #ifndef EIGEN_UMFPACKSUPPORT_H
26 #define EIGEN_UMFPACKSUPPORT_H
27
28 /* TODO extract L, extract U, compute det, etc... */
29
30 // generic double/complex<double> wrapper functions:
31
32 inline void umfpack_free_numeric(void **Numeric, double)
33 { umfpack_di_free_numeric(Numeric); }
34
35 inline void umfpack_free_numeric(void **Numeric, std::complex<double>)
36 { umfpack_zi_free_numeric(Numeric); }
37
38 inline void umfpack_free_symbolic(void **Symbolic, double)
39 { umfpack_di_free_symbolic(Symbolic); }
40
41 inline void umfpack_free_symbolic(void **Symbolic, std::complex<double>)
42 { umfpack_zi_free_symbolic(Symbolic); }
43
44 inline int umfpack_symbolic(int n_row,int n_col,
45                             const int Ap[], const int Ai[], const double Ax[], void **Symbolic,
46                             const double Control [UMFPACK_CONTROL], double Info [UMFPACK_INFO])
47 {
48   return umfpack_di_symbolic(n_row,n_col,Ap,Ai,Ax,Symbolic,Control,Info);
49 }
50
51 inline int umfpack_symbolic(int n_row,int n_col,
52                             const int Ap[], const int Ai[], const std::complex<double> Ax[], void **Symbolic,
53                             const double Control [UMFPACK_CONTROL], double Info [UMFPACK_INFO])
54 {
55   return umfpack_zi_symbolic(n_row,n_col,Ap,Ai,&Ax[0].real(),0,Symbolic,Control,Info);
56 }
57
58 inline int umfpack_numeric( const int Ap[], const int Ai[], const double Ax[],
59                             void *Symbolic, void **Numeric,
60                             const double Control[UMFPACK_CONTROL],double Info [UMFPACK_INFO])
61 {
62   return umfpack_di_numeric(Ap,Ai,Ax,Symbolic,Numeric,Control,Info);
63 }
64
65 inline int umfpack_numeric( const int Ap[], const int Ai[], const std::complex<double> Ax[],
66                             void *Symbolic, void **Numeric,
67                             const double Control[UMFPACK_CONTROL],double Info [UMFPACK_INFO])
68 {
69   return umfpack_zi_numeric(Ap,Ai,&Ax[0].real(),0,Symbolic,Numeric,Control,Info);
70 }
71
72 inline int umfpack_solve( int sys, const int Ap[], const int Ai[], const double Ax[],
73                           double X[], const double B[], void *Numeric,
74                           const double Control[UMFPACK_CONTROL], double Info[UMFPACK_INFO])
75 {
76   return umfpack_di_solve(sys,Ap,Ai,Ax,X,B,Numeric,Control,Info);
77 }
78
79 inline int umfpack_solve( int sys, const int Ap[], const int Ai[], const std::complex<double> Ax[],
80                           std::complex<double> X[], const std::complex<double> B[], void *Numeric,
81                           const double Control[UMFPACK_CONTROL], double Info[UMFPACK_INFO])
82 {
83   return umfpack_zi_solve(sys,Ap,Ai,&Ax[0].real(),0,&X[0].real(),0,&B[0].real(),0,Numeric,Control,Info);
84 }
85
86 inline int umfpack_get_lunz(int *lnz, int *unz, int *n_row, int *n_col, int *nz_udiag, void *Numeric, double)
87 {
88   return umfpack_di_get_lunz(lnz,unz,n_row,n_col,nz_udiag,Numeric);
89 }
90
91 inline int umfpack_get_lunz(int *lnz, int *unz, int *n_row, int *n_col, int *nz_udiag, void *Numeric, std::complex<double>)
92 {
93   return umfpack_zi_get_lunz(lnz,unz,n_row,n_col,nz_udiag,Numeric);
94 }
95
96 inline int umfpack_get_numeric(int Lp[], int Lj[], double Lx[], int Up[], int Ui[], double Ux[],
97                                int P[], int Q[], double Dx[], int *do_recip, double Rs[], void *Numeric)
98 {
99   return umfpack_di_get_numeric(Lp,Lj,Lx,Up,Ui,Ux,P,Q,Dx,do_recip,Rs,Numeric);
100 }
101
102 inline int umfpack_get_numeric(int Lp[], int Lj[], std::complex<double> Lx[], int Up[], int Ui[], std::complex<double> Ux[],
103                                int P[], int Q[], std::complex<double> Dx[], int *do_recip, double Rs[], void *Numeric)
104 {
105   return umfpack_zi_get_numeric(Lp,Lj,Lx?&Lx[0].real():0,0,Up,Ui,Ux?&Ux[0].real():0,0,P,Q,
106                                Dx?&Dx[0].real():0,0,do_recip,Rs,Numeric);
107 }
108
109 inline int umfpack_get_determinant(double *Mx, double *Ex, void *NumericHandle, double User_Info [UMFPACK_INFO])
110 {
111   return umfpack_di_get_determinant(Mx,Ex,NumericHandle,User_Info);
112 }
113
114 inline int umfpack_get_determinant(std::complex<double> *Mx, double *Ex, void *NumericHandle, double User_Info [UMFPACK_INFO])
115 {
116   return umfpack_zi_get_determinant(&Mx->real(),0,Ex,NumericHandle,User_Info);
117 }
118
119
120 template<typename MatrixType>
121 class SparseLU<MatrixType,UmfPack> : public SparseLU<MatrixType>
122 {
123   protected:
124     typedef SparseLU<MatrixType> Base;
125     typedef typename Base::Scalar Scalar;
126     typedef typename Base::RealScalar RealScalar;
127     typedef Matrix<Scalar,Dynamic,1> Vector;
128     typedef Matrix<int, 1, MatrixType::ColsAtCompileTime> IntRowVectorType;
129     typedef Matrix<int, MatrixType::RowsAtCompileTime, 1> IntColVectorType;
130     typedef SparseMatrix<Scalar,LowerTriangular|UnitDiagBit> LMatrixType;
131     typedef SparseMatrix<Scalar,UpperTriangular> UMatrixType;
132     using Base::m_flags;
133     using Base::m_status;
134
135   public:
136
137     SparseLU(int flags = NaturalOrdering)
138       : Base(flags), m_numeric(0)
139     {
140     }
141
142     SparseLU(const MatrixType& matrix, int flags = NaturalOrdering)
143       : Base(flags), m_numeric(0)
144     {
145       compute(matrix);
146     }
147
148     ~SparseLU()
149     {
150       if (m_numeric)
151         umfpack_free_numeric(&m_numeric,Scalar());
152     }
153
154     inline const LMatrixType& matrixL() const
155     {
156       if (m_extractedDataAreDirty) extractData();
157       return m_l;
158     }
159
160     inline const UMatrixType& matrixU() const
161     {
162       if (m_extractedDataAreDirty) extractData();
163       return m_u;
164     }
165
166     inline const IntColVectorType& permutationP() const
167     {
168       if (m_extractedDataAreDirty) extractData();
169       return m_p;
170     }
171
172     inline const IntRowVectorType& permutationQ() const
173     {
174       if (m_extractedDataAreDirty) extractData();
175       return m_q;
176     }
177
178     Scalar determinant() const;
179
180     template<typename BDerived, typename XDerived>
181     bool solve(const MatrixBase<BDerived> &b, MatrixBase<XDerived>* x) const;
182
183     void compute(const MatrixType& matrix);
184
185   protected:
186
187     void extractData() const;
188   
189   protected:
190     // cached data:
191     void* m_numeric;
192     const MatrixType* m_matrixRef;
193     mutable LMatrixType m_l;
194     mutable UMatrixType m_u;
195     mutable IntColVectorType m_p;
196     mutable IntRowVectorType m_q;
197     mutable bool m_extractedDataAreDirty;
198 };
199
200 template<typename MatrixType>
201 void SparseLU<MatrixType,UmfPack>::compute(const MatrixType& a)
202 {
203   const int rows = a.rows();
204   const int cols = a.cols();
205   ei_assert((MatrixType::Flags&RowMajorBit)==0 && "Row major matrices are not supported yet");
206
207   m_matrixRef = &a;
208
209   if (m_numeric)
210     umfpack_free_numeric(&m_numeric,Scalar());
211
212   void* symbolic;
213   int errorCode = 0;
214   errorCode = umfpack_symbolic(rows, cols, a._outerIndexPtr(), a._innerIndexPtr(), a._valuePtr(),
215                                   &symbolic, 0, 0);
216   if (errorCode==0)
217     errorCode = umfpack_numeric(a._outerIndexPtr(), a._innerIndexPtr(), a._valuePtr(),
218                                    symbolic, &m_numeric, 0, 0);
219
220   umfpack_free_symbolic(&symbolic,Scalar());
221
222   m_extractedDataAreDirty = true;
223
224   Base::m_succeeded = (errorCode==0);
225 }
226
227 template<typename MatrixType>
228 void SparseLU<MatrixType,UmfPack>::extractData() const
229 {
230   if (m_extractedDataAreDirty)
231   {
232     // get size of the data
233     int lnz, unz, rows, cols, nz_udiag;
234     umfpack_get_lunz(&lnz, &unz, &rows, &cols, &nz_udiag, m_numeric, Scalar());
235
236     // allocate data
237     m_l.resize(rows,std::min(rows,cols));
238     m_l.resizeNonZeros(lnz);
239     
240     m_u.resize(std::min(rows,cols),cols);
241     m_u.resizeNonZeros(unz);
242
243     m_p.resize(rows);
244     m_q.resize(cols);
245
246     // extract
247     umfpack_get_numeric(m_l._outerIndexPtr(), m_l._innerIndexPtr(), m_l._valuePtr(),
248                         m_u._outerIndexPtr(), m_u._innerIndexPtr(), m_u._valuePtr(),
249                         m_p.data(), m_q.data(), 0, 0, 0, m_numeric);
250     
251     m_extractedDataAreDirty = false;
252   }
253 }
254
255 template<typename MatrixType>
256 typename SparseLU<MatrixType,UmfPack>::Scalar SparseLU<MatrixType,UmfPack>::determinant() const
257 {
258   Scalar det;
259   umfpack_get_determinant(&det, 0, m_numeric, 0);
260   return det;
261 }
262
263 template<typename MatrixType>
264 template<typename BDerived,typename XDerived>
265 bool SparseLU<MatrixType,UmfPack>::solve(const MatrixBase<BDerived> &b, MatrixBase<XDerived> *x) const
266 {
267   //const int size = m_matrix.rows();
268   const int rhsCols = b.cols();
269 //   ei_assert(size==b.rows());
270   ei_assert((BDerived::Flags&RowMajorBit)==0 && "UmfPack backend does not support non col-major rhs yet");
271   ei_assert((XDerived::Flags&RowMajorBit)==0 && "UmfPack backend does not support non col-major result yet");
272
273   int errorCode;
274   for (int j=0; j<rhsCols; ++j)
275   {
276     errorCode = umfpack_solve(UMFPACK_A,
277         m_matrixRef->_outerIndexPtr(), m_matrixRef->_innerIndexPtr(), m_matrixRef->_valuePtr(),
278         &x->col(j).coeffRef(0), &b.const_cast_derived().col(j).coeffRef(0), m_numeric, 0, 0);
279     if (errorCode!=0)
280       return false;
281   }
282 //   errorCode = umfpack_di_solve(UMFPACK_A,
283 //       m_matrixRef._outerIndexPtr(), m_matrixRef._innerIndexPtr(), m_matrixRef._valuePtr(),
284 //       x->derived().data(), b.derived().data(), m_numeric, 0, 0);
285
286   return true;
287 }
288
289 #endif // EIGEN_UMFPACKSUPPORT_H