22cb1d7a0882f147bfbb9cbaae3c63dfa76f279f
[blender-staging.git] / intern / iksolver / intern / TNT / svd.h
1 /**
2  * $Id$
3  */
4
5 #ifndef SVD_H
6
7 #define SVD_H
8
9 // Compute the Singular Value Decomposition of an arbitrary matrix A.
10 // That is, compute the 3 matrices U,W,V with U column orthogonal (m,n),
11 // W a diagonal matrix and V an orthogonal square matrix s.t.
12 // A = U.W.Vt. From this decomposition it is trivial to compute the
13 // inverse of A as Ainv = V.Winv.transpose(U).
14 // work_space is a temporary vector used by SVD() to hold
15 // intermediate values during the computation of the SVD. SVD() resizes
16 // it to a.num_cols() itself, so the caller need not size it beforehand.
17
18 #include "tntmath.h"
19
20 namespace TNT
21 {
22
23
24 template <class MaTRiX, class VecToR >
25 void SVD(MaTRiX &a, VecToR &w,  MaTRiX &v, VecToR &work_space) {
26
27                 int n = a.num_cols();
28         int m = a.num_rows();
29
30         int flag,i,its,j,jj,k,l(0),nm(0);
31         typename MaTRiX::value_type c,f,h,x,y,z;
32         typename MaTRiX::value_type anorm(0),g(0),scale(0);
33     typename MaTRiX::value_type s(0);
34
35     work_space.newsize(n);
36
37         for (i=1;i<=n;i++) {
38                 l=i+1;
39                 work_space(i)=scale*g;
40
41                 g = (typename MaTRiX::value_type)0;
42
43                 s = (typename  MaTRiX::value_type)0;
44         scale = (typename  MaTRiX::value_type)0;
45
46                 if (i <= m) {
47                         for (k=i;k<=m;k++) scale += TNT::abs(a(k,i));
48                         if (scale > (typename  MaTRiX::value_type)0) {
49                                 for (k=i;k<=m;k++) {
50                                         a(k,i) /= scale;
51                                         s += a(k,i)*a(k,i);
52                                 }
53                                 f=a(i,i);
54                                 g = -TNT::sign(sqrt(s),f);
55                                 h=f*g-s;
56                                 a(i,i)=f-g;
57                                 if (i != n) {
58                                         for (j=l;j<=n;j++) {
59                         s = (typename  MaTRiX::value_type)0;
60                                                 for (k=i;k<=m;k++) s += a(k,i)*a(k,j);
61                                                 f=s/h;
62                                                 for (k=i;k<=m;k++) a(k,j) += f*a(k,i);
63                                         }
64                                 }
65                                 for (k=i;k<=m;k++) a(k,i) *= scale;
66                         }
67                 }
68                 w(i)=scale*g;
69         g = (typename  MaTRiX::value_type)0;
70         s = (typename  MaTRiX::value_type)0;
71         scale = (typename  MaTRiX::value_type)0;
72                 if (i <= m && i != n) {
73                         for (k=l;k<=n;k++) scale += TNT::abs(a(i,k));
74                         if (scale > (typename  MaTRiX::value_type)0) {
75                                 for (k=l;k<=n;k++) {
76                                         a(i,k) /= scale;
77                                         s += a(i,k)*a(i,k);
78                                 }
79                                 f=a(i,l);
80                                 g = -TNT::sign(sqrt(s),f);
81                                 h=f*g-s;
82                                 a(i,l)=f-g;
83                                 for (k=l;k<=n;k++) work_space(k)=a(i,k)/h;
84                                 if (i != m) {
85                                         for (j=l;j<=m;j++) {
86                         s = (typename  MaTRiX::value_type)0;
87                                                 for (k=l;k<=n;k++) s += a(j,k)*a(i,k);
88                                                 for (k=l;k<=n;k++) a(j,k) += s*work_space(k);
89                                         }
90                                 }
91                                 for (k=l;k<=n;k++) a(i,k) *= scale;
92                         }
93                 }
94                 anorm=TNT::max(anorm,(TNT::abs(w(i))+TNT::abs(work_space(i))));
95         }
96         for (i=n;i>=1;i--) {
97                 if (i < n) {
98                         if (g != (typename  MaTRiX::value_type)0) {
99                                 for (j=l;j<=n;j++)
100                                         v(j,i)=(a(i,j)/a(i,l))/g;
101                                 for (j=l;j<=n;j++) {
102                     s = (typename  MaTRiX::value_type)0;
103                                         for (k=l;k<=n;k++) s += a(i,k)*v(k,j);
104                                         for (k=l;k<=n;k++) v(k,j) += s*v(k,i);
105                                 }
106                         }
107                         for (j=l;j<=n;j++) v(i,j)=v(j,i)= (typename  MaTRiX::value_type)0;
108                 }
109                 v(i,i)= (typename  MaTRiX::value_type)1;
110                 g=work_space(i);
111                 l=i;
112         }
113         for (i=n;i>=1;i--) {
114                 l=i+1;
115                 g=w(i);
116                 if (i < n) {
117                         for (j=l;j<=n;j++) a(i,j)= (typename  MaTRiX::value_type)0;
118                 }
119                 if (g !=  (typename  MaTRiX::value_type)0) {
120                         g= ((typename  MaTRiX::value_type)1)/g;
121                         if (i != n) {
122                                 for (j=l;j<=n;j++) {
123                     s =  (typename  MaTRiX::value_type)0;
124                                         for (k=l;k<=m;k++) s += a(k,i)*a(k,j);
125                                         f=(s/a(i,i))*g;
126                                         for (k=i;k<=m;k++) a(k,j) += f*a(k,i);
127                                 }
128                         }
129                         for (j=i;j<=m;j++) a(j,i) *= g;
130                 } else {
131                         for (j=i;j<=m;j++) a(j,i)= (typename  MaTRiX::value_type)0;
132                 }
133                 ++a(i,i);
134         }
135         for (k=n;k>=1;k--) {
136                 for (its=1;its<=30;its++) {
137                         flag=1;
138                         for (l=k;l>=1;l--) {
139                                 nm=l-1;
140                                 if (TNT::abs(work_space(l))+anorm == anorm) {
141                                         flag=0;
142                                         break;
143                                 }
144                                 if (TNT::abs(w(nm))+anorm == anorm) break;
145                         }
146                         if (flag) {
147                                 c= (typename  MaTRiX::value_type)0;
148                                 s= (typename  MaTRiX::value_type)1;
149                                 for (i=l;i<=k;i++) {
150                                         f=s*work_space(i);
151                                         if (TNT::abs(f)+anorm != anorm) {
152                                                 g=w(i);
153                                                 h= (typename  MaTRiX::value_type)TNT::pythag(float(f),float(g));
154                                                 w(i)=h;
155                                                 h= ((typename  MaTRiX::value_type)1)/h;
156                                                 c=g*h;
157                                                 s=(-f*h);
158                                                 for (j=1;j<=m;j++) {
159                                                         y=a(j,nm);
160                                                         z=a(j,i);
161                                                         a(j,nm)=y*c+z*s;
162                                                         a(j,i)=z*c-y*s;
163                                                 }
164                                         }
165                                 }
166                         }
167                         z=w(k);
168                         if (l == k) {
169                                 if (z <  (typename  MaTRiX::value_type)0) {
170                                         w(k) = -z;
171                                         for (j=1;j<=n;j++) v(j,k)=(-v(j,k));
172                                 }
173                                 break;
174                         }
175
176
177 #if 1
178                         if (its == 30)
179                         {
180                                 TNTException an_exception;
181                                 an_exception.i = 0;
182                                 throw an_exception;
183
184                                 return ;
185                                 assert(false);
186                         }
187 #endif
188                         x=w(l);
189                         nm=k-1;
190                         y=w(nm);
191                         g=work_space(nm);
192                         h=work_space(k);
193                         f=((y-z)*(y+z)+(g-h)*(g+h))/(((typename  MaTRiX::value_type)2)*h*y);
194                         g=(typename  MaTRiX::value_type)TNT::pythag(float(f), float(1));
195                         f=((x-z)*(x+z)+h*((y/(f+TNT::sign(g,f)))-h))/x;
196                         c =  (typename  MaTRiX::value_type)1;
197                         s =  (typename  MaTRiX::value_type)1;
198                         for (j=l;j<=nm;j++) {
199                                 i=j+1;
200                                 g=work_space(i);
201                                 y=w(i);
202                                 h=s*g;
203                                 g=c*g;
204                                 z=(typename  MaTRiX::value_type)TNT::pythag(float(f),float(h));
205                                 work_space(j)=z;
206                                 c=f/z;
207                                 s=h/z;
208                                 f=x*c+g*s;
209                                 g=g*c-x*s;
210                                 h=y*s;
211                                 y=y*c;
212                                 for (jj=1;jj<=n;jj++) {
213                                         x=v(jj,j);
214                                         z=v(jj,i);
215                                         v(jj,j)=x*c+z*s;
216                                         v(jj,i)=z*c-x*s;
217                                 }
218                                 z=(typename  MaTRiX::value_type)TNT::pythag(float(f),float(h));
219                                 w(j)=z;
220                                 if (z !=  (typename  MaTRiX::value_type)0) {
221                                         z= ((typename  MaTRiX::value_type)1)/z;
222                                         c=f*z;
223                                         s=h*z;
224                                 }
225                                 f=(c*g)+(s*y);
226                                 x=(c*y)-(s*g);
227                                 for (jj=1;jj<=m;jj++) {
228                                         y=a(jj,j);
229                                         z=a(jj,i);
230                                         a(jj,j)=y*c+z*s;
231                                         a(jj,i)=z*c-y*s;
232                                 }
233                         }
234                         work_space(l)= (typename  MaTRiX::value_type)0;
235                         work_space(k)=f;
236                         w(k)=x;
237                 }
238         }
239 }
240
241 // A is replaced by the column orthogonal matrix U 
242
243 template <class MaTRiX, class VecToR >
244 void SVD_a( MaTRiX &a, VecToR &w,  MaTRiX &v) {
245
246         int n = a.num_cols();
247         int m = a.num_rows();
248
249         int flag,i,its,j,jj,k,l,nm;
250         typename MaTRiX::value_type anorm,c,f,g,h,s,scale,x,y,z;
251
252         VecToR work_space;
253         work_space.newsize(n);
254
255         g = scale = anorm = 0.0;
256         
257         for (i=1;i <=n;i++) {
258                 l = i+1;
259                 work_space(i) = scale*g;
260                 g = s=scale=0.0;
261
262                 if (i <= m) {
263                         for(k=i; k<=m; k++) scale += abs(a(k,i));
264
265                         if (scale) {
266                                 for (k = i; k <=m ; k++) {
267                                         a(k,i) /= scale;
268                                         s += a(k,i)*a(k,i);
269                                 }
270                                 f = a(i,i);
271                                 g = -sign(sqrt(s),f);
272                                 h = f*g -s;
273                                 a(i,i) = f-g;
274         
275                                 for (j = l; j <=n; j++) {
276                                         for (s = 0.0,k =i;k<=m;k++) s += a(k,i)*a(k,j);
277                                         f = s/h;
278                                         for (k = i; k <= m; k++) a(k,j) += f*a(k,i);
279                                 }
280                                 for (k = i; k <=m;k++) a(k,i) *= scale;
281                         }
282                 }
283
284                 w(i) = scale*g;
285                 g = s = scale = 0.0;
286
287                 if (i <=m && i != n) {
288                         for (k = l; k <=n;k++) scale += abs(a(i,k));
289                         if (scale) {
290                                 for(k = l;k <=n;k++) {
291                                         a(i,k) /= scale;
292                                         s += a(i,k) * a(i,k);
293                                 }
294
295                                 f = a(i,l);
296                                 g = -sign(sqrt(s),f);
297                                 h= f*g -s;
298                                 a(i,l) = f-g;
299                                 for(k=l;k<=n;k++) work_space(k) = a(i,k)/h;
300                                 for (j=l;j<=m;j++) {
301                                         for(s = 0.0,k=l;k<=n;k++) s+= a(j,k)*a(i,k);
302                                         for(k=l;k<=n;k++) a(j,k) += s*work_space(k);
303                                 }
304                                 for(k=l;k<=n;k++) a(i,k)*=scale;
305                         }
306                 }
307                 anorm = max(anorm,(abs(w(i)) + abs(work_space(i))));
308         }
309         for (i=n;i>=1;i--) {
310                 if (i <n) {
311                         if (g) {
312                                 for(j=l;j<=n;j++) v(j,i) = (a(i,j)/a(i,l))/g;
313                                 for(j=l;j<=n;j++) {
314                                         for(s=0.0,k=l;k<=n;k++) s += a(i,k)*v(k,j);
315                                         for(k=l; k<=n;k++) v(k,j) +=s*v(k,i);
316                                 }
317                         }
318                         for(j=l;j <=n;j++) v(i,j) = v(j,i) = 0.0;
319                 }
320                 v(i,i) = 1.0;
321                 g = work_space(i);
322                 l = i;
323         }
324
325         for (i = min(m,n);i>=1;i--) {
326                 l = i+1;
327                 g = w(i);
328                 for (j=l;j <=n;j++) a(i,j) = 0.0;
329                 if (g) {
330                         g = 1.0/g;
331                         for (j=l;j<=n;j++) {
332                                 for (s = 0.0,k=l;k<=m;k++) s += a(k,i)*a(k,j);
333                                 f = (s/a(i,i))*g;
334                                 for (k=i;k<=m;k++) a(k,j) += f*a(k,i);  
335                         }
336                         for (j=i;j<=m;j++) a(j,i)*=g;
337                 } else {
338                         for (j=i;j<=m;j++) a(j,i) = 0.0;
339                 }
340                 ++a(i,i);
341         }
342
343         for (k=n;k>=1;k--) {
344                 for (its=1;its<=30;its++) {
345                         flag=1;
346                         for(l=k;l>=1;l--) {
347                                 nm = l-1;
348                                 if (abs(work_space(l)) + anorm == anorm) {
349                                         flag = 0;
350                                         break;
351                                 }
352                                 if (abs(w(nm)) + anorm == anorm) break;
353                         }
354                         if (flag) {
355                                 c = 0.0;
356                                 s = 1.0;
357                                 for (i=l;i<=k;i++) {
358                                         f = s*work_space(i);
359                                         work_space(i) = c*work_space(i);
360                                         if (abs(f) +anorm == anorm) break;
361                                         g = w(i);
362                                         h  = pythag(f,g);
363                                         w(i) = h;
364                                         h = 1.0/h;
365                                         c = g*h;
366                                         s = -f*h;
367                                         for (j=1;j<=m;j++) {
368                                                 y= a(j,nm);
369                                                 z=a(j,i);
370                                                 a(j,nm) = y*c + z*s;
371                                                 a(j,i) = z*c - y*s;
372                                         }
373                                 }
374                         }
375                         z=w(k);
376                         if (l==k) {
377                                 if (z <0.0) {
378                                         w(k) = -z;
379                                         for (j=1;j<=n;j++) v(j,k) = -v(j,k);
380                                 }
381                                 break;
382                         }
383
384                         if (its == 30) assert(false);
385
386                         x=w(l);
387                         nm=k-1;
388                         y=w(nm);
389                         g=work_space(nm);
390                         h=work_space(k);
391                         
392                         f= ((y-z)*(y+z) + (g-h)*(g+h))/(2.0*h*y);
393                         g = pythag(f,1.0);
394                         f= ((x-z)*(x+z) + h*((y/(f + sign(g,f)))-h))/x;
395                         c=s=1.0;
396
397                         for (j=l;j<=nm;j++) {
398                                 i=j+1;
399                                 g = work_space(i);
400                                 y=w(i);
401                                 h=s*g;
402                                 g=c*g;
403                                 z=pythag(f,h);
404                                 work_space(j) = z;
405                                 c=f/z;
406                                 s=h/z;
407                                 f=x*c + g*s;
408                                 g= g*c - x*s;
409                                 h=y*s;
410                                 y*=c;
411                                 for(jj=1;jj<=n;jj++) {
412                                         x=v(jj,j);
413                                         z=v(jj,i);
414                                         v(jj,j) = x*c + z*s;
415                                         v(jj,i) = z*c- x*s;
416                                 }
417                                 z=pythag(f,h);
418                                 w(j)=z;
419                                 if(z) {
420                                         z = 1.0/z;
421                                         c=f*z;
422                                         s=h*z;
423                                 }
424                                 f=c*g + s*y;
425                                 x= c*y-s*g;
426                         
427                                 for(jj=1;jj<=m;jj++) {
428                                         y=a(jj,j);
429                                         z=a(jj,i);
430                                         a(jj,j) = y*c+z*s;
431                                         a(jj,i) = z*c - y*s;
432                                 }
433                         }
434
435                         work_space(l) = 0.0;
436                         work_space(k) = f;
437                         w(k) = x;
438                 }
439         }
440 }
441
442 }
443
444 #endif
445