076dce3b8077f6d4294847f522e749fd8e3373f4
[blender.git] / intern / opennl / intern / opennl.cpp
/** \file opennl/intern/opennl.cpp
2  *  \ingroup opennlintern
3  */
4 /*
5  *
6  *  OpenNL: Numerical Library
7  *  Copyright (C) 2004 Bruno Levy
8  *
9  *  This program is free software; you can redistribute it and/or modify
10  *  it under the terms of the GNU General Public License as published by
11  *  the Free Software Foundation; either version 2 of the License, or
12  *  (at your option) any later version.
13  *
14  *  This program is distributed in the hope that it will be useful,
15  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17  *  GNU General Public License for more details.
18  *
19  *  You should have received a copy of the GNU General Public License
20  *  along with this program; if not, write to the Free Software
21  *  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
22  *
23  *  If you modify this software, you should include a notice giving the
24  *  name of the person performing the modification, the date of modification,
25  *  and the reason for such modification.
26  *
27  *  Contact: Bruno Levy
28  *
29  *       levy@loria.fr
30  *
31  *       ISA Project
32  *       LORIA, INRIA Lorraine,
33  *       Campus Scientifique, BP 239
34  *       54506 VANDOEUVRE LES NANCY CEDEX
35  *       FRANCE
36  *
37  *  Note that the GNU General Public License does not permit incorporating
38  *  the Software into proprietary programs.
39  */
40
41 #include "ONL_opennl.h"
42
43 #include <Eigen/Sparse>
44
45 #include <algorithm>
46 #include <cassert>
47 #include <cstdlib>
48 #include <iostream>
49 #include <vector>
50
51 /* Eigen data structures */
52
53 typedef Eigen::SparseMatrix<double, Eigen::ColMajor> EigenSparseMatrix;
54 typedef Eigen::SparseLU<EigenSparseMatrix> EigenSparseSolver;
55 typedef Eigen::VectorXd EigenVectorX;
56 typedef Eigen::Triplet<double> EigenTriplet;
57
58 /* NLContext data structure */
59
60 typedef struct {
61         NLuint index;
62         NLdouble value;
63 } NLCoeff;
64
65 typedef struct {
66         NLdouble value[4];
67         NLboolean locked;
68         NLuint index;
69         std::vector<NLCoeff> a;
70 } NLVariable;
71
/* Context states, in the order a system moves through them:
 * nlBegin(NL_SYSTEM) -> SYSTEM, nlBegin(NL_MATRIX) -> MATRIX,
 * nlEnd(NL_MATRIX) -> MATRIX_CONSTRUCTED, nlEnd(NL_SYSTEM) -> SYSTEM_CONSTRUCTED,
 * successful nlSolve() -> SYSTEM_SOLVED. */
enum {
	NL_STATE_INITIAL            = 0,
	NL_STATE_SYSTEM             = 1,
	NL_STATE_MATRIX             = 2,
	NL_STATE_MATRIX_CONSTRUCTED = 3,
	NL_STATE_SYSTEM_CONSTRUCTED = 4,
	NL_STATE_SYSTEM_SOLVED      = 5
};
78
79 struct NLContext {
80         NLenum state;
81
82         NLuint n;
83         NLuint m;
84
85         std::vector<EigenTriplet> Mtriplets;
86         EigenSparseMatrix M;
87         EigenSparseMatrix MtM;
88         std::vector<EigenVectorX> b;
89         std::vector<EigenVectorX> Mtb;
90         std::vector<EigenVectorX> x;
91
92         EigenSparseSolver *sparse_solver;
93
94         NLuint nb_variables;
95         std::vector<NLVariable> variable;
96
97         NLuint nb_rows;
98         NLuint nb_rhs;
99
100         NLboolean least_squares;
101         NLboolean solve_again;
102 };
103
104 NLContext *nlNewContext(void)
105 {
106         NLContext* result = new NLContext();
107         result->state = NL_STATE_INITIAL;
108         result->nb_rhs = 1;
109         return result;
110 }
111
112 void nlDeleteContext(NLContext *context)
113 {
114         context->M.resize(0, 0);
115         context->MtM.resize(0, 0);
116         context->b.clear();
117         context->Mtb.clear();
118         context->x.clear();
119         context->variable.clear();
120
121         delete context->sparse_solver;
122         context->sparse_solver = NULL;
123
124         delete context;
125 }
126
127 static void __nlCheckState(NLContext *context, NLenum state)
128 {
129         assert(context->state == state);
130 }
131
132 static void __nlTransition(NLContext *context, NLenum from_state, NLenum to_state)
133 {
134         __nlCheckState(context, from_state);
135         context->state = to_state;
136 }
137
138 /* Get/Set parameters */
139
140 void nlSolverParameteri(NLContext *context, NLenum pname, NLint param)
141 {
142         __nlCheckState(context, NL_STATE_INITIAL);
143         switch(pname) {
144         case NL_NB_VARIABLES: {
145                 assert(param > 0);
146                 context->nb_variables = (NLuint)param;
147         } break;
148         case NL_NB_ROWS: {
149                 assert(param > 0);
150                 context->nb_rows = (NLuint)param;
151         } break;
152         case NL_LEAST_SQUARES: {
153                 context->least_squares = (NLboolean)param;
154         } break;
155         case NL_NB_RIGHT_HAND_SIDES: {
156                 context->nb_rhs = (NLuint)param;
157         } break;
158         default: {
159                 assert(0);
160         } break;
161         }
162 }
163
164 /* Get/Set Lock/Unlock variables */
165
166 void nlSetVariable(NLContext *context, NLuint rhsindex, NLuint index, NLdouble value)
167 {
168         __nlCheckState(context, NL_STATE_SYSTEM);
169         context->variable[index].value[rhsindex] = value;
170 }
171
172 NLdouble nlGetVariable(NLContext *context, NLuint rhsindex, NLuint index)
173 {
174         assert(context->state != NL_STATE_INITIAL);
175         return context->variable[index].value[rhsindex];
176 }
177
178 void nlLockVariable(NLContext *context, NLuint index)
179 {
180         __nlCheckState(context, NL_STATE_SYSTEM);
181         context->variable[index].locked = true;
182 }
183
184 void nlUnlockVariable(NLContext *context, NLuint index)
185 {
186         __nlCheckState(context, NL_STATE_SYSTEM);
187         context->variable[index].locked = false;
188 }
189
190 /* System construction */
191
192 static void __nlVariablesToVector(NLContext *context)
193 {
194         NLuint i, j, nb_rhs;
195
196         nb_rhs= context->nb_rhs;
197
198         for(i=0; i<context->nb_variables; i++) {
199                 NLVariable* v = &(context->variable[i]);
200                 if(!v->locked) {
201                         for(j=0; j<nb_rhs; j++)
202                                 context->x[j][v->index] = v->value[j];
203                 }
204         }
205 }
206
207 static void __nlVectorToVariables(NLContext *context)
208 {
209         NLuint i, j, nb_rhs;
210
211         nb_rhs= context->nb_rhs;
212
213         for(i=0; i<context->nb_variables; i++) {
214                 NLVariable* v = &(context->variable[i]);
215                 if(!v->locked) {
216                         for(j=0; j<nb_rhs; j++)
217                                 v->value[j] = context->x[j][v->index];
218                 }
219         }
220 }
221
222 static void __nlBeginSystem(NLContext *context)
223 {
224         assert(context->nb_variables > 0);
225
226         if (context->solve_again)
227                 __nlTransition(context, NL_STATE_SYSTEM_SOLVED, NL_STATE_SYSTEM);
228         else {
229                 __nlTransition(context, NL_STATE_INITIAL, NL_STATE_SYSTEM);
230
231                 context->variable.resize(context->nb_variables);
232         }
233 }
234
235 static void __nlEndSystem(NLContext *context)
236 {
237         __nlTransition(context, NL_STATE_MATRIX_CONSTRUCTED, NL_STATE_SYSTEM_CONSTRUCTED);
238 }
239
240 static void __nlBeginMatrix(NLContext *context)
241 {
242         NLuint i;
243         NLuint m = 0, n = 0;
244
245         __nlTransition(context, NL_STATE_SYSTEM, NL_STATE_MATRIX);
246
247         if (!context->solve_again) {
248                 for(i=0; i<context->nb_variables; i++) {
249                         if(context->variable[i].locked)
250                                 context->variable[i].index = ~0;
251                         else
252                                 context->variable[i].index = n++;
253                 }
254
255                 m = (context->nb_rows == 0)? n: context->nb_rows;
256
257                 context->m = m;
258                 context->n = n;
259
260                 /* reserve reasonable estimate */
261                 context->Mtriplets.clear();
262                 context->Mtriplets.reserve(std::max(m, n)*3);
263
264                 context->b.resize(context->nb_rhs);
265                 context->x.resize(context->nb_rhs);
266
267                 for (i=0; i<context->nb_rhs; i++) {
268                         context->b[i].setZero(m);
269                         context->x[i].setZero(n);
270                 }
271         }
272         else {
273                 /* need to recompute b only, A is not constructed anymore */
274                 for (i=0; i<context->nb_rhs; i++)
275                         context->b[i].setZero(context->m);
276         }
277
278         __nlVariablesToVector(context);
279 }
280
281 static void __nlEndMatrixRHS(NLContext *context, NLuint rhs)
282 {
283         NLVariable *variable;
284         NLuint i, j;
285
286         EigenVectorX& b = context->b[rhs];
287
288         for(i=0; i<context->nb_variables; i++) {
289                 variable = &(context->variable[i]);
290
291                 if(variable->locked) {
292                         std::vector<NLCoeff>& a = variable->a;
293
294                         for(j=0; j<a.size(); j++) {
295                                 b[a[j].index] -= a[j].value*variable->value[rhs];
296                         }
297                 }
298         }
299
300         if(context->least_squares)
301                 context->Mtb[rhs] = context->M.transpose() * b;
302 }
303
304 static void __nlEndMatrix(NLContext *context)
305 {
306         __nlTransition(context, NL_STATE_MATRIX, NL_STATE_MATRIX_CONSTRUCTED);
307
308         if(!context->solve_again) {
309                 context->M.resize(context->m, context->n);
310                 context->M.setFromTriplets(context->Mtriplets.begin(), context->Mtriplets.end());
311                 context->Mtriplets.clear();
312
313                 if(context->least_squares) {
314                         context->MtM = context->M.transpose() * context->M;
315
316                         context->Mtb.resize(context->nb_rhs);
317                         for (NLuint rhs=0; rhs<context->nb_rhs; rhs++)
318                                 context->Mtb[rhs].setZero(context->n);
319                 }
320         }
321
322         for (NLuint rhs=0; rhs<context->nb_rhs; rhs++)
323                 __nlEndMatrixRHS(context, rhs);
324 }
325
326 void nlMatrixAdd(NLContext *context, NLuint row, NLuint col, NLdouble value)
327 {
328         __nlCheckState(context, NL_STATE_MATRIX);
329
330         if(context->solve_again)
331                 return;
332
333         if (!context->least_squares && context->variable[row].locked);
334         else if (context->variable[col].locked) {
335                 if(!context->least_squares)
336                         row = context->variable[row].index;
337
338                 NLCoeff coeff = {row, value};
339                 context->variable[col].a.push_back(coeff);
340         }
341         else {
342                 if(!context->least_squares)
343                         row = context->variable[row].index;
344                 col = context->variable[col].index;
345
346                 // direct insert into matrix is too slow, so use triplets
347                 EigenTriplet triplet(row, col, value);
348                 context->Mtriplets.push_back(triplet);
349         }
350 }
351
352 void nlRightHandSideAdd(NLContext *context, NLuint rhsindex, NLuint index, NLdouble value)
353 {
354         __nlCheckState(context, NL_STATE_MATRIX);
355
356         if(context->least_squares) {
357                 context->b[rhsindex][index] += value;
358         }
359         else {
360                 if(!context->variable[index].locked) {
361                         index = context->variable[index].index;
362                         context->b[rhsindex][index] += value;
363                 }
364         }
365 }
366
367 void nlRightHandSideSet(NLContext *context, NLuint rhsindex, NLuint index, NLdouble value)
368 {
369         __nlCheckState(context, NL_STATE_MATRIX);
370
371         if(context->least_squares) {
372                 context->b[rhsindex][index] = value;
373         }
374         else {
375                 if(!context->variable[index].locked) {
376                         index = context->variable[index].index;
377                         context->b[rhsindex][index] = value;
378                 }
379         }
380 }
381
382 void nlBegin(NLContext *context, NLenum prim)
383 {
384         switch(prim) {
385         case NL_SYSTEM: {
386                 __nlBeginSystem(context);
387         } break;
388         case NL_MATRIX: {
389                 __nlBeginMatrix(context);
390         } break;
391         default: {
392                 assert(0);
393         }
394         }
395 }
396
397 void nlEnd(NLContext *context, NLenum prim)
398 {
399         switch(prim) {
400         case NL_SYSTEM: {
401                 __nlEndSystem(context);
402         } break;
403         case NL_MATRIX: {
404                 __nlEndMatrix(context);
405         } break;
406         default: {
407                 assert(0);
408         }
409         }
410 }
411
412 void nlPrintMatrix(NLContext *context)
413 {
414         std::cout << "A:" << context->M << std::endl;
415
416         for(NLuint rhs=0; rhs<context->nb_rhs; rhs++)
417                 std::cout << "b " << rhs << ":" << context->b[rhs] << std::endl;
418
419         if (context->MtM.rows() && context->MtM.cols())
420                 std::cout << "AtA:" << context->MtM << std::endl;
421 }
422
423 /* Solving */
424
425 NLboolean nlSolve(NLContext *context, NLboolean solveAgain)
426 {
427         NLboolean result = true;
428
429         __nlCheckState(context, NL_STATE_SYSTEM_CONSTRUCTED);
430
431         if (!context->solve_again) {
432                 EigenSparseMatrix& M = (context->least_squares)? context->MtM: context->M;
433
434                 assert(M.rows() == M.cols());
435
436                 /* Convert M to compressed column format */
437                 M.makeCompressed();
438
439                 /* Perform sparse LU factorization */
440                 EigenSparseSolver *sparse_solver = new EigenSparseSolver();
441                 context->sparse_solver = sparse_solver;
442
443                 sparse_solver->analyzePattern(M);
444                 sparse_solver->factorize(M);
445
446                 result = (sparse_solver->info() == Eigen::Success);
447
448                 /* Free M, don't need it anymore at this point */
449                 M.resize(0, 0);
450         }
451
452         if (result) {
453                 /* Solve each right hand side */
454                 for(NLuint rhs=0; rhs<context->nb_rhs; rhs++) {
455                         EigenVectorX& b = (context->least_squares)? context->Mtb[rhs]: context->b[rhs];
456                         context->x[rhs] = context->sparse_solver->solve(b);
457
458                         if (context->sparse_solver->info() != Eigen::Success)
459                                 result = false;
460                 }
461
462                 if (result) {
463                         __nlVectorToVariables(context);
464
465                         if (solveAgain)
466                                 context->solve_again = true;
467
468                         __nlTransition(context, NL_STATE_SYSTEM_CONSTRUCTED, NL_STATE_SYSTEM_SOLVED);
469                 }
470         }
471
472         return result;
473 }
474