Code cleanup:
[blender.git] / intern / cycles / render / graph.cpp
1 /*
2  * Copyright 2011, Blender Foundation.
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU General Public License
6  * as published by the Free Software Foundation; either version 2
7  * of the License, or (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software Foundation,
16  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17  */
18
19 #include "attribute.h"
20 #include "graph.h"
21 #include "nodes.h"
22
23 #include "util_algorithm.h"
24 #include "util_debug.h"
25 #include "util_foreach.h"
26
27 CCL_NAMESPACE_BEGIN
28
29 /* Input and Output */
30
31 ShaderInput::ShaderInput(ShaderNode *parent_, const char *name_, ShaderSocketType type_)
32 {
33         parent = parent_;
34         name = name_;
35         type = type_;
36         link = NULL;
37         value = make_float3(0, 0, 0);
38         stack_offset = SVM_STACK_INVALID;
39         default_value = NONE;
40         usage = USE_ALL;
41 }
42
43 ShaderOutput::ShaderOutput(ShaderNode *parent_, const char *name_, ShaderSocketType type_)
44 {
45         parent = parent_;
46         name = name_;
47         type = type_;
48         stack_offset = SVM_STACK_INVALID;
49 }
50
51 /* Node */
52
53 ShaderNode::ShaderNode(const char *name_)
54 {
55         name = name_;
56         id = -1;
57         bump = SHADER_BUMP_NONE;
58         special_type = SHADER_SPECIAL_TYPE_NONE;
59 }
60
61 ShaderNode::~ShaderNode()
62 {
63         foreach(ShaderInput *socket, inputs)
64                 delete socket;
65
66         foreach(ShaderOutput *socket, outputs)
67                 delete socket;
68 }
69
70 ShaderInput *ShaderNode::input(const char *name)
71 {
72         foreach(ShaderInput *socket, inputs)
73                 if(strcmp(socket->name, name) == 0)
74                         return socket;
75
76         return NULL;
77 }
78
79 ShaderOutput *ShaderNode::output(const char *name)
80 {
81         foreach(ShaderOutput *socket, outputs)
82                 if(strcmp(socket->name, name) == 0)
83                         return socket;
84
85         return NULL;
86 }
87
88 ShaderInput *ShaderNode::add_input(const char *name, ShaderSocketType type, float value, int usage)
89 {
90         ShaderInput *input = new ShaderInput(this, name, type);
91         input->value.x = value;
92         input->usage = usage;
93         inputs.push_back(input);
94         return input;
95 }
96
97 ShaderInput *ShaderNode::add_input(const char *name, ShaderSocketType type, float3 value, int usage)
98 {
99         ShaderInput *input = new ShaderInput(this, name, type);
100         input->value = value;
101         input->usage = usage;
102         inputs.push_back(input);
103         return input;
104 }
105
106 ShaderInput *ShaderNode::add_input(const char *name, ShaderSocketType type, ShaderInput::DefaultValue value, int usage)
107 {
108         ShaderInput *input = add_input(name, type);
109         input->default_value = value;
110         input->usage = usage;
111         return input;
112 }
113
114 ShaderOutput *ShaderNode::add_output(const char *name, ShaderSocketType type)
115 {
116         ShaderOutput *output = new ShaderOutput(this, name, type);
117         outputs.push_back(output);
118         return output;
119 }
120
121 void ShaderNode::attributes(AttributeRequestSet *attributes)
122 {
123         foreach(ShaderInput *input, inputs) {
124                 if(!input->link) {
125                         if(input->default_value == ShaderInput::TEXTURE_GENERATED)
126                                 attributes->add(ATTR_STD_GENERATED);
127                         else if(input->default_value == ShaderInput::TEXTURE_UV)
128                                 attributes->add(ATTR_STD_UV);
129                 }
130         }
131 }
132
133 /* Graph */
134
135 ShaderGraph::ShaderGraph()
136 {
137         finalized = false;
138         num_node_ids = 0;
139         add(new OutputNode());
140 }
141
142 ShaderGraph::~ShaderGraph()
143 {
144         foreach(ShaderNode *node, nodes)
145                 delete node;
146 }
147
148 ShaderNode *ShaderGraph::add(ShaderNode *node)
149 {
150         assert(!finalized);
151         node->id = num_node_ids++;
152         nodes.push_back(node);
153         return node;
154 }
155
156 ShaderNode *ShaderGraph::output()
157 {
158         return nodes.front();
159 }
160
161 ShaderGraph *ShaderGraph::copy()
162 {
163         ShaderGraph *newgraph = new ShaderGraph();
164
165         /* copy nodes */
166         set<ShaderNode*> nodes_all;
167         foreach(ShaderNode *node, nodes)
168                 nodes_all.insert(node);
169
170         map<ShaderNode*, ShaderNode*> nodes_copy;
171         copy_nodes(nodes_all, nodes_copy);
172
173         /* add nodes (in same order, so output is still first) */
174         newgraph->nodes.clear();
175         foreach(ShaderNode *node, nodes)
176                 newgraph->add(nodes_copy[node]);
177
178         return newgraph;
179 }
180
181 void ShaderGraph::connect(ShaderOutput *from, ShaderInput *to)
182 {
183         assert(!finalized);
184         assert(from && to);
185
186         if(to->link) {
187                 fprintf(stderr, "Cycles shader graph connect: input already connected.\n");
188                 return;
189         }
190
191         if(from->type != to->type) {
192                 /* for closures we can't do automatic conversion */
193                 if(from->type == SHADER_SOCKET_CLOSURE || to->type == SHADER_SOCKET_CLOSURE) {
194                         fprintf(stderr, "Cycles shader graph connect: can only connect closure to closure "
195                                 "(%s.%s to %s.%s).\n",
196                                 from->parent->name.c_str(), from->name,
197                                 to->parent->name.c_str(), to->name);
198                         return;
199                 }
200
201                 /* add automatic conversion node in case of type mismatch */
202                 ShaderNode *convert = add(new ConvertNode(from->type, to->type, true));
203
204                 connect(from, convert->inputs[0]);
205                 connect(convert->outputs[0], to);
206         }
207         else {
208                 /* types match, just connect */
209                 to->link = from;
210                 from->links.push_back(to);
211         }
212 }
213
214 void ShaderGraph::disconnect(ShaderInput *to)
215 {
216         assert(!finalized);
217         assert(to->link);
218
219         ShaderOutput *from = to->link;
220
221         to->link = NULL;
222         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
223 }
224
225 void ShaderGraph::finalize(bool do_bump, bool do_osl, bool do_multi_transform)
226 {
227         /* before compiling, the shader graph may undergo a number of modifications.
228          * currently we set default geometry shader inputs, and create automatic bump
229          * from displacement. a graph can be finalized only once, and should not be
230          * modified afterwards. */
231
232         if(!finalized) {
233                 clean();
234                 default_inputs(do_osl);
235                 refine_bump_nodes();
236
237                 if(do_bump)
238                         bump_from_displacement();
239
240                 if(do_multi_transform) {
241                         ShaderInput *surface_in = output()->input("Surface");
242                         ShaderInput *volume_in = output()->input("Volume");
243
244                         /* todo: make this work when surface and volume closures are tangled up */
245
246                         if(surface_in->link)
247                                 transform_multi_closure(surface_in->link->parent, NULL, false);
248                         if(volume_in->link)
249                                 transform_multi_closure(volume_in->link->parent, NULL, true);
250                 }
251
252                 finalized = true;
253         }
254 }
255
256 void ShaderGraph::find_dependencies(set<ShaderNode*>& dependencies, ShaderInput *input)
257 {
258         /* find all nodes that this input depends on directly and indirectly */
259         ShaderNode *node = (input->link)? input->link->parent: NULL;
260
261         if(node) {
262                 foreach(ShaderInput *in, node->inputs)
263                         find_dependencies(dependencies, in);
264
265                 dependencies.insert(node);
266         }
267 }
268
269 void ShaderGraph::copy_nodes(set<ShaderNode*>& nodes, map<ShaderNode*, ShaderNode*>& nnodemap)
270 {
271         /* copy a set of nodes, and the links between them. the assumption is
272          * made that all nodes that inputs are linked to are in the set too. */
273
274         /* copy nodes */
275         foreach(ShaderNode *node, nodes) {
276                 ShaderNode *nnode = node->clone();
277                 nnodemap[node] = nnode;
278
279                 nnode->inputs.clear();
280                 nnode->outputs.clear();
281
282                 foreach(ShaderInput *input, node->inputs) {
283                         ShaderInput *ninput = new ShaderInput(*input);
284                         nnode->inputs.push_back(ninput);
285
286                         ninput->parent = nnode;
287                         ninput->link = NULL;
288                 }
289
290                 foreach(ShaderOutput *output, node->outputs) {
291                         ShaderOutput *noutput = new ShaderOutput(*output);
292                         nnode->outputs.push_back(noutput);
293
294                         noutput->parent = nnode;
295                         noutput->links.clear();
296                 }
297         }
298
299         /* recreate links */
300         foreach(ShaderNode *node, nodes) {
301                 foreach(ShaderInput *input, node->inputs) {
302                         if(input->link) {
303                                 /* find new input and output */
304                                 ShaderNode *nfrom = nnodemap[input->link->parent];
305                                 ShaderNode *nto = nnodemap[input->parent];
306                                 ShaderOutput *noutput = nfrom->output(input->link->name);
307                                 ShaderInput *ninput = nto->input(input->name);
308
309                                 /* connect */
310                                 connect(noutput, ninput);
311                         }
312                 }
313         }
314 }
315
316 void ShaderGraph::remove_unneeded_nodes()
317 {
318         vector<bool> removed(num_node_ids, false);
319         bool any_node_removed = false;
320         
321         /* find and unlink proxy nodes */
322         foreach(ShaderNode *node, nodes) {
323                 if(node->special_type == SHADER_SPECIAL_TYPE_PROXY) {
324                         ProxyNode *proxy = static_cast<ProxyNode*>(node);
325                         ShaderInput *input = proxy->inputs[0];
326                         ShaderOutput *output = proxy->outputs[0];
327                         
328                         /* temp. copy of the output links list.
329                          * output->links is modified when we disconnect!
330                          */
331                         vector<ShaderInput*> links(output->links);
332                         ShaderOutput *from = input->link;
333                         
334                         /* bypass the proxy node */
335                         if(from) {
336                                 disconnect(input);
337                                 foreach(ShaderInput *to, links) {
338                                         disconnect(to);
339                                         connect(from, to);
340                                 }
341                         }
342                         else {
343                                 foreach(ShaderInput *to, links) {
344                                         /* remove any autoconvert nodes too if they lead to
345                                          * sockets with an automatically set default value */
346                                         ShaderNode *tonode = to->parent;
347
348                                         if(tonode->special_type == SHADER_SPECIAL_TYPE_AUTOCONVERT) {
349                                                 bool all_links_removed = true;
350                                                 vector<ShaderInput*> links = tonode->outputs[0]->links;
351
352                                                 foreach(ShaderInput *autoin, links) {
353                                                         if(autoin->default_value == ShaderInput::NONE)
354                                                                 all_links_removed = false;
355                                                         else
356                                                                 disconnect(autoin);
357                                                 }
358
359                                                 if(all_links_removed)
360                                                         removed[tonode->id] = true;
361                                         }
362
363                                         disconnect(to);
364                                         
365                                         /* transfer the default input value to the target socket */
366                                         to->set(input->value);
367                                         to->set(input->value_string);
368                                 }
369                         }
370                         
371                         removed[proxy->id] = true;
372                         any_node_removed = true;
373                 }
374                 else if(node->special_type == SHADER_SPECIAL_TYPE_MIX_CLOSURE) {
375                         MixClosureNode *mix = static_cast<MixClosureNode*>(node);
376
377                         /* remove useless mix closures nodes */
378                         if(mix->outputs[0]->links.size() && mix->inputs[1]->link == mix->inputs[2]->link) {
379                                 ShaderOutput *output = mix->inputs[1]->link;
380                                 vector<ShaderInput*> inputs = mix->outputs[0]->links;
381
382                                 foreach(ShaderInput *sock, mix->inputs)
383                                         if(sock->link)
384                                                 disconnect(sock);
385
386                                 foreach(ShaderInput *input, inputs) {
387                                         disconnect(input);
388                                         if(output)
389                                                 connect(output, input);
390                                 }
391                         }
392                 
393                         /* remove unused mix closure input when factor is 0.0 or 1.0 */
394                         /* check for closure links and make sure factor link is disconnected */
395                         if(mix->outputs[0]->links.size() && mix->inputs[1]->link && mix->inputs[2]->link && !mix->inputs[0]->link) {
396                                 /* factor 0.0 */
397                                 if(mix->inputs[0]->value.x == 0.0f) {
398                                         ShaderOutput *output = mix->inputs[1]->link;
399                                         vector<ShaderInput*> inputs = mix->outputs[0]->links;
400                                         
401                                         foreach(ShaderInput *sock, mix->inputs)
402                                                 if(sock->link)
403                                                         disconnect(sock);
404
405                                         foreach(ShaderInput *input, inputs) {
406                                                 disconnect(input);
407                                                 if(output)
408                                                         connect(output, input);
409                                         }
410                                 }
411                                 /* factor 1.0 */
412                                 else if(mix->inputs[0]->value.x == 1.0f) {
413                                         ShaderOutput *output = mix->inputs[2]->link;
414                                         vector<ShaderInput*> inputs = mix->outputs[0]->links;
415                                         
416                                         foreach(ShaderInput *sock, mix->inputs)
417                                                 if(sock->link)
418                                                         disconnect(sock);
419
420                                         foreach(ShaderInput *input, inputs) {
421                                                 disconnect(input);
422                                                 if(output)
423                                                         connect(output, input);
424                                         }
425                                 }
426                         }
427                 }
428         }
429
430         /* remove nodes */
431         if (any_node_removed) {
432                 list<ShaderNode*> newnodes;
433
434                 foreach(ShaderNode *node, nodes) {
435                         if(!removed[node->id])
436                                 newnodes.push_back(node);
437                         else
438                                 delete node;
439                 }
440
441                 nodes = newnodes;
442         }
443 }
444
445 void ShaderGraph::break_cycles(ShaderNode *node, vector<bool>& visited, vector<bool>& on_stack)
446 {
447         visited[node->id] = true;
448         on_stack[node->id] = true;
449
450         foreach(ShaderInput *input, node->inputs) {
451                 if(input->link) {
452                         ShaderNode *depnode = input->link->parent;
453
454                         if(on_stack[depnode->id]) {
455                                 /* break cycle */
456                                 disconnect(input);
457                                 fprintf(stderr, "Cycles shader graph: detected cycle in graph, connection removed.\n");
458                         }
459                         else if(!visited[depnode->id]) {
460                                 /* visit dependencies */
461                                 break_cycles(depnode, visited, on_stack);
462                         }
463                 }
464         }
465
466         on_stack[node->id] = false;
467 }
468
469 void ShaderGraph::clean()
470 {
471         /* remove proxy and unnecessary mix nodes */
472         remove_unneeded_nodes();
473
474         /* we do two things here: find cycles and break them, and remove unused
475          * nodes that don't feed into the output. how cycles are broken is
476          * undefined, they are invalid input, the important thing is to not crash */
477
478         vector<bool> visited(num_node_ids, false);
479         vector<bool> on_stack(num_node_ids, false);
480         
481         /* break cycles */
482         break_cycles(output(), visited, on_stack);
483
484         /* disconnect unused nodes */
485         foreach(ShaderNode *node, nodes) {
486                 if(!visited[node->id]) {
487                         foreach(ShaderInput *to, node->inputs) {
488                                 ShaderOutput *from = to->link;
489
490                                 if(from) {
491                                         to->link = NULL;
492                                         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
493                                 }
494                         }
495                 }
496         }
497
498         /* remove unused nodes */
499         list<ShaderNode*> newnodes;
500
501         foreach(ShaderNode *node, nodes) {
502                 if(visited[node->id])
503                         newnodes.push_back(node);
504                 else
505                         delete node;
506         }
507         
508         nodes = newnodes;
509 }
510
511 void ShaderGraph::default_inputs(bool do_osl)
512 {
513         /* nodes can specify default texture coordinates, for now we give
514          * everything the position by default, except for the sky texture */
515
516         ShaderNode *geom = NULL;
517         ShaderNode *texco = NULL;
518
519         foreach(ShaderNode *node, nodes) {
520                 foreach(ShaderInput *input, node->inputs) {
521                         if(!input->link && ((input->usage & ShaderInput::USE_SVM) || do_osl)) {
522                                 if(input->default_value == ShaderInput::TEXTURE_GENERATED) {
523                                         if(!texco)
524                                                 texco = new TextureCoordinateNode();
525
526                                         connect(texco->output("Generated"), input);
527                                 }
528                                 else if(input->default_value == ShaderInput::TEXTURE_UV) {
529                                         if(!texco)
530                                                 texco = new TextureCoordinateNode();
531
532                                         connect(texco->output("UV"), input);
533                                 }
534                                 else if(input->default_value == ShaderInput::INCOMING) {
535                                         if(!geom)
536                                                 geom = new GeometryNode();
537
538                                         connect(geom->output("Incoming"), input);
539                                 }
540                                 else if(input->default_value == ShaderInput::NORMAL) {
541                                         if(!geom)
542                                                 geom = new GeometryNode();
543
544                                         connect(geom->output("Normal"), input);
545                                 }
546                                 else if(input->default_value == ShaderInput::POSITION) {
547                                         if(!geom)
548                                                 geom = new GeometryNode();
549
550                                         connect(geom->output("Position"), input);
551                                 }
552                                 else if(input->default_value == ShaderInput::TANGENT) {
553                                         if(!geom)
554                                                 geom = new GeometryNode();
555
556                                         connect(geom->output("Tangent"), input);
557                                 }
558                         }
559                 }
560         }
561
562         if(geom)
563                 add(geom);
564         if(texco)
565                 add(texco);
566 }
567
568 void ShaderGraph::refine_bump_nodes()
569 {
570         /* we transverse the node graph looking for bump nodes, when we find them,
571          * like in bump_from_displacement(), we copy the sub-graph defined from "bump"
572          * input to the inputs "center","dx" and "dy" What is in "bump" input is moved
573          * to "center" input. */
574
575         foreach(ShaderNode *node, nodes) {
576                 if(node->name == ustring("bump") && node->input("Height")->link) {
577                         ShaderInput *bump_input = node->input("Height");
578                         set<ShaderNode*> nodes_bump;
579
580                         /* make 2 extra copies of the subgraph defined in Bump input */
581                         map<ShaderNode*, ShaderNode*> nodes_dx;
582                         map<ShaderNode*, ShaderNode*> nodes_dy;
583
584                         /* find dependencies for the given input */
585                         find_dependencies(nodes_bump, bump_input );
586
587                         copy_nodes(nodes_bump, nodes_dx);
588                         copy_nodes(nodes_bump, nodes_dy);
589         
590                         /* mark nodes to indicate they are use for bump computation, so
591                            that any texture coordinates are shifted by dx/dy when sampling */
592                         foreach(ShaderNode *node, nodes_bump)
593                                 node->bump = SHADER_BUMP_CENTER;
594                         foreach(NodePair& pair, nodes_dx)
595                                 pair.second->bump = SHADER_BUMP_DX;
596                         foreach(NodePair& pair, nodes_dy)
597                                 pair.second->bump = SHADER_BUMP_DY;
598
599                         ShaderOutput *out = bump_input->link;
600                         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name);
601                         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name);
602
603                         connect(out_dx, node->input("SampleX"));
604                         connect(out_dy, node->input("SampleY"));
605                         
606                         /* add generated nodes */
607                         foreach(NodePair& pair, nodes_dx)
608                                 add(pair.second);
609                         foreach(NodePair& pair, nodes_dy)
610                                 add(pair.second);
611                         
612                         /* connect what is conected is bump to samplecenter input*/
613                         connect(out , node->input("SampleCenter"));
614
615                         /* bump input is just for connectivity purpose for the graph input,
616                          * we reconected this input to samplecenter, so lets disconnect it
617                          * from bump input */
618                         disconnect(bump_input);
619                 }
620         }
621 }
622
623 void ShaderGraph::bump_from_displacement()
624 {
625         /* generate bump mapping automatically from displacement. bump mapping is
626          * done using a 3-tap filter, computing the displacement at the center,
627          * and two other positions shifted by ray differentials.
628          *
629          * since the input to displacement is a node graph, we need to ensure that
630          * all texture coordinates use are shift by the ray differentials. for this
631          * reason we make 3 copies of the node subgraph defining the displacement,
632          * with each different geometry and texture coordinate nodes that generate
633          * different shifted coordinates.
634          *
635          * these 3 displacement values are then fed into the bump node, which will
636          * output the the perturbed normal. */
637
638         ShaderInput *displacement_in = output()->input("Displacement");
639
640         if(!displacement_in->link)
641                 return;
642         
643         /* find dependencies for the given input */
644         set<ShaderNode*> nodes_displace;
645         find_dependencies(nodes_displace, displacement_in);
646
647         /* copy nodes for 3 bump samples */
648         map<ShaderNode*, ShaderNode*> nodes_center;
649         map<ShaderNode*, ShaderNode*> nodes_dx;
650         map<ShaderNode*, ShaderNode*> nodes_dy;
651
652         copy_nodes(nodes_displace, nodes_center);
653         copy_nodes(nodes_displace, nodes_dx);
654         copy_nodes(nodes_displace, nodes_dy);
655
656         /* mark nodes to indicate they are use for bump computation, so
657          * that any texture coordinates are shifted by dx/dy when sampling */
658         foreach(NodePair& pair, nodes_center)
659                 pair.second->bump = SHADER_BUMP_CENTER;
660         foreach(NodePair& pair, nodes_dx)
661                 pair.second->bump = SHADER_BUMP_DX;
662         foreach(NodePair& pair, nodes_dy)
663                 pair.second->bump = SHADER_BUMP_DY;
664
665         /* add set normal node and connect the bump normal ouput to the set normal
666          * output, so it can finally set the shader normal, note we are only doing
667          * this for bump from displacement, this will be the only bump allowed to
668          * overwrite the shader normal */
669         ShaderNode *set_normal = add(new SetNormalNode());
670         
671         /* add bump node and connect copied graphs to it */
672         ShaderNode *bump = add(new BumpNode());
673
674         ShaderOutput *out = displacement_in->link;
675         ShaderOutput *out_center = nodes_center[out->parent]->output(out->name);
676         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name);
677         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name);
678
679         connect(out_center, bump->input("SampleCenter"));
680         connect(out_dx, bump->input("SampleX"));
681         connect(out_dy, bump->input("SampleY"));
682         
683         /* connect the bump out to the set normal in: */
684         connect(bump->output("Normal"), set_normal->input("Direction"));
685
686         /* connect bump output to normal input nodes that aren't set yet. actually
687          * this will only set the normal input to the geometry node that we created
688          * and connected to all other normal inputs already. */
689         foreach(ShaderNode *node, nodes)
690                 foreach(ShaderInput *input, node->inputs)
691                         if(!input->link && input->default_value == ShaderInput::NORMAL)
692                                 connect(set_normal->output("Normal"), input);
693
694         /* for displacement bump, clear the normal input in case the above loop
695          * connected the setnormal out to the bump normalin */
696         ShaderInput *bump_normal_in = bump->input("Normal");
697         if(bump_normal_in)
698                 bump_normal_in->link = NULL;
699
700         /* finally, add the copied nodes to the graph. we can't do this earlier
701          * because we would create dependency cycles in the above loop */
702         foreach(NodePair& pair, nodes_center)
703                 add(pair.second);
704         foreach(NodePair& pair, nodes_dx)
705                 add(pair.second);
706         foreach(NodePair& pair, nodes_dy)
707                 add(pair.second);
708 }
709
710 void ShaderGraph::transform_multi_closure(ShaderNode *node, ShaderOutput *weight_out, bool volume)
711 {
712         /* for SVM in multi closure mode, this transforms the shader mix/add part of
713          * the graph into nodes that feed weights into closure nodes. this is too
714          * avoid building a closure tree and then flattening it, and instead write it
715          * directly to an array */
716         
717         if(node->name == ustring("mix_closure") || node->name == ustring("add_closure")) {
718                 ShaderInput *fin = node->input("Fac");
719                 ShaderInput *cl1in = node->input("Closure1");
720                 ShaderInput *cl2in = node->input("Closure2");
721                 ShaderOutput *weight1_out, *weight2_out;
722
723                 if(fin) {
724                         /* mix closure: add node to mix closure weights */
725                         ShaderNode *mix_node = add(new MixClosureWeightNode());
726                         ShaderInput *fac_in = mix_node->input("Fac"); 
727                         ShaderInput *weight_in = mix_node->input("Weight"); 
728
729                         if(fin->link)
730                                 connect(fin->link, fac_in);
731                         else
732                                 fac_in->value = fin->value;
733
734                         if(weight_out)
735                                 connect(weight_out, weight_in);
736
737                         weight1_out = mix_node->output("Weight1");
738                         weight2_out = mix_node->output("Weight2");
739                 }
740                 else {
741                         /* add closure: just pass on any weights */
742                         weight1_out = weight_out;
743                         weight2_out = weight_out;
744                 }
745
746                 if(cl1in->link)
747                         transform_multi_closure(cl1in->link->parent, weight1_out, volume);
748                 if(cl2in->link)
749                         transform_multi_closure(cl2in->link->parent, weight2_out, volume);
750         }
751         else {
752                 ShaderInput *weight_in = node->input((volume)? "VolumeMixWeight": "SurfaceMixWeight");
753
754                 /* not a closure node? */
755                 if(!weight_in)
756                         return;
757
758                 /* already has a weight connected to it? add weights */
759                 if(weight_in->link || weight_in->value.x != 0.0f) {
760                         ShaderNode *math_node = add(new MathNode());
761                         ShaderInput *value1_in = math_node->input("Value1");
762                         ShaderInput *value2_in = math_node->input("Value2");
763
764                         if(weight_in->link)
765                                 connect(weight_in->link, value1_in);
766                         else
767                                 value1_in->value = weight_in->value;
768
769                         if(weight_out)
770                                 connect(weight_out, value2_in);
771                         else
772                                 value2_in->value.x = 1.0f;
773
774                         weight_out = math_node->output("Value");
775                         if(weight_in->link)
776                                 disconnect(weight_in);
777                 }
778
779                 /* connected to closure mix weight */
780                 if(weight_out)
781                         connect(weight_out, weight_in);
782                 else
783                         weight_in->value.x += 1.0f;
784         }
785 }
786
787 CCL_NAMESPACE_END
788