Cycles: Remove meaningless volume shaders
[blender.git] / intern / cycles / render / graph.cpp
1 /*
2  * Copyright 2011-2016 Blender Foundation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "render/attribute.h"
18 #include "render/graph.h"
19 #include "render/nodes.h"
20 #include "render/scene.h"
21 #include "render/shader.h"
22 #include "render/constant_fold.h"
23
24 #include "util/util_algorithm.h"
25 #include "util/util_debug.h"
26 #include "util/util_foreach.h"
27 #include "util/util_queue.h"
28 #include "util/util_logging.h"
29
30 CCL_NAMESPACE_BEGIN
31
32 namespace {
33
34 bool check_node_inputs_has_links(const ShaderNode *node)
35 {
36         foreach(const ShaderInput *in, node->inputs) {
37                 if(in->link) {
38                         return true;
39                 }
40         }
41         return false;
42 }
43
44 bool check_node_inputs_traversed(const ShaderNode *node,
45                                  const ShaderNodeSet& done)
46 {
47         foreach(const ShaderInput *in, node->inputs) {
48                 if(in->link) {
49                         if(done.find(in->link->parent) == done.end()) {
50                                 return false;
51                         }
52                 }
53         }
54         return true;
55 }
56
57 }  /* namespace */
58
59 /* Node */
60
61 ShaderNode::ShaderNode(const NodeType *type)
62 : Node(type)
63 {
64         name = type->name;
65         id = -1;
66         bump = SHADER_BUMP_NONE;
67         special_type = SHADER_SPECIAL_TYPE_NONE;
68
69         create_inputs_outputs(type);
70 }
71
72 ShaderNode::~ShaderNode()
73 {
74         foreach(ShaderInput *socket, inputs)
75                 delete socket;
76
77         foreach(ShaderOutput *socket, outputs)
78                 delete socket;
79 }
80
81 void ShaderNode::create_inputs_outputs(const NodeType *type)
82 {
83         foreach(const SocketType& socket, type->inputs) {
84                 if(socket.flags & SocketType::LINKABLE) {
85                         inputs.push_back(new ShaderInput(socket, this));
86                 }
87         }
88
89         foreach(const SocketType& socket, type->outputs) {
90                 outputs.push_back(new ShaderOutput(socket, this));
91         }
92 }
93
94 ShaderInput *ShaderNode::input(const char *name)
95 {
96         foreach(ShaderInput *socket, inputs) {
97                 if(socket->name() == name)
98                         return socket;
99         }
100
101         return NULL;
102 }
103
104 ShaderOutput *ShaderNode::output(const char *name)
105 {
106         foreach(ShaderOutput *socket, outputs)
107                 if(socket->name() == name)
108                         return socket;
109
110         return NULL;
111 }
112
113 ShaderInput *ShaderNode::input(ustring name)
114 {
115         foreach(ShaderInput *socket, inputs) {
116                 if(socket->name() == name)
117                         return socket;
118         }
119
120         return NULL;
121 }
122
123 ShaderOutput *ShaderNode::output(ustring name)
124 {
125         foreach(ShaderOutput *socket, outputs)
126                 if(socket->name() == name)
127                         return socket;
128
129         return NULL;
130 }
131
132 void ShaderNode::attributes(Shader *shader, AttributeRequestSet *attributes)
133 {
134         foreach(ShaderInput *input, inputs) {
135                 if(!input->link) {
136                         if(input->flags() & SocketType::LINK_TEXTURE_GENERATED) {
137                                 if(shader->has_surface)
138                                         attributes->add(ATTR_STD_GENERATED);
139                                 if(shader->has_volume)
140                                         attributes->add(ATTR_STD_GENERATED_TRANSFORM);
141                         }
142                         else if(input->flags() & SocketType::LINK_TEXTURE_UV) {
143                                 if(shader->has_surface)
144                                         attributes->add(ATTR_STD_UV);
145                         }
146                 }
147         }
148 }
149
150 bool ShaderNode::equals(const ShaderNode& other)
151 {
152         if(type != other.type || bump != other.bump) {
153                 return false;
154         }
155
156         assert(inputs.size() == other.inputs.size());
157
158         /* Compare unlinkable sockets */
159         foreach(const SocketType& socket, type->inputs) {
160                 if(!(socket.flags & SocketType::LINKABLE)) {
161                         if(!Node::equals_value(other, socket)) {
162                                 return false;
163                         }
164                 }
165         }
166
167         /* Compare linkable input sockets */
168         for(int i = 0; i < inputs.size(); ++i) {
169                 ShaderInput *input_a = inputs[i],
170                             *input_b = other.inputs[i];
171                 if(input_a->link == NULL && input_b->link == NULL) {
172                         /* Unconnected inputs are expected to have the same value. */
173                         if(!Node::equals_value(other, input_a->socket_type)) {
174                                 return false;
175                         }
176                 }
177                 else if(input_a->link != NULL && input_b->link != NULL) {
178                         /* Expect links are to come from the same exact socket. */
179                         if(input_a->link != input_b->link) {
180                                 return false;
181                         }
182                 }
183                 else {
184                         /* One socket has a link and another has not, inputs can't be
185                          * considered equal.
186                          */
187                         return false;
188                 }
189         }
190
191         return true;
192 }
193
194 /* Graph */
195
196 ShaderGraph::ShaderGraph()
197 {
198         finalized = false;
199         simplified = false;
200         num_node_ids = 0;
201         add(new OutputNode());
202 }
203
204 ShaderGraph::~ShaderGraph()
205 {
206         clear_nodes();
207 }
208
209 ShaderNode *ShaderGraph::add(ShaderNode *node)
210 {
211         assert(!finalized);
212         simplified = false;
213
214         node->id = num_node_ids++;
215         nodes.push_back(node);
216         return node;
217 }
218
219 OutputNode *ShaderGraph::output()
220 {
221         return (OutputNode*)nodes.front();
222 }
223
224 ShaderGraph *ShaderGraph::copy()
225 {
226         ShaderGraph *newgraph = new ShaderGraph();
227
228         /* copy nodes */
229         ShaderNodeSet nodes_all;
230         foreach(ShaderNode *node, nodes)
231                 nodes_all.insert(node);
232
233         ShaderNodeMap nodes_copy;
234         copy_nodes(nodes_all, nodes_copy);
235
236         /* add nodes (in same order, so output is still first) */
237         newgraph->clear_nodes();
238         foreach(ShaderNode *node, nodes)
239                 newgraph->add(nodes_copy[node]);
240
241         newgraph->simplified = simplified;
242
243         return newgraph;
244 }
245
246 void ShaderGraph::connect(ShaderOutput *from, ShaderInput *to)
247 {
248         assert(!finalized);
249         assert(from && to);
250
251         if(to->link) {
252                 fprintf(stderr, "Cycles shader graph connect: input already connected.\n");
253                 return;
254         }
255
256         if(from->type() != to->type()) {
257                 /* for closures we can't do automatic conversion */
258                 if(from->type() == SocketType::CLOSURE || to->type() == SocketType::CLOSURE) {
259                         fprintf(stderr, "Cycles shader graph connect: can only connect closure to closure "
260                                 "(%s.%s to %s.%s).\n",
261                                 from->parent->name.c_str(), from->name().c_str(),
262                                 to->parent->name.c_str(), to->name().c_str());
263                         return;
264                 }
265
266                 /* add automatic conversion node in case of type mismatch */
267                 ShaderNode *convert = add(new ConvertNode(from->type(), to->type(), true));
268
269                 connect(from, convert->inputs[0]);
270                 connect(convert->outputs[0], to);
271         }
272         else {
273                 /* types match, just connect */
274                 to->link = from;
275                 from->links.push_back(to);
276         }
277 }
278
279 void ShaderGraph::disconnect(ShaderOutput *from)
280 {
281         assert(!finalized);
282         simplified = false;
283
284         foreach(ShaderInput *sock, from->links) {
285                 sock->link = NULL;
286         }
287
288         from->links.clear();
289 }
290
291 void ShaderGraph::disconnect(ShaderInput *to)
292 {
293         assert(!finalized);
294         assert(to->link);
295         simplified = false;
296
297         ShaderOutput *from = to->link;
298
299         to->link = NULL;
300         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
301 }
302
303 void ShaderGraph::relink(ShaderNode *node, ShaderOutput *from, ShaderOutput *to)
304 {
305         simplified = false;
306
307         /* Copy because disconnect modifies this list */
308         vector<ShaderInput*> outputs = from->links;
309
310         /* Bypass node by moving all links from "from" to "to" */
311         foreach(ShaderInput *sock, node->inputs) {
312                 if(sock->link)
313                         disconnect(sock);
314         }
315
316         foreach(ShaderInput *sock, outputs) {
317                 disconnect(sock);
318                 if(to)
319                         connect(to, sock);
320         }
321 }
322
323 void ShaderGraph::simplify(Scene *scene)
324 {
325         if(!simplified) {
326                 default_inputs(scene->shader_manager->use_osl());
327                 clean(scene);
328                 refine_bump_nodes();
329
330                 simplified = true;
331         }
332 }
333
334 void ShaderGraph::finalize(Scene *scene,
335                            bool do_bump,
336                            bool do_simplify,
337                            bool bump_in_object_space)
338 {
339         /* before compiling, the shader graph may undergo a number of modifications.
340          * currently we set default geometry shader inputs, and create automatic bump
341          * from displacement. a graph can be finalized only once, and should not be
342          * modified afterwards. */
343
344         if(!finalized) {
345                 simplify(scene);
346
347                 if(do_bump)
348                         bump_from_displacement(bump_in_object_space);
349
350                 ShaderInput *surface_in = output()->input("Surface");
351                 ShaderInput *volume_in = output()->input("Volume");
352
353                 /* todo: make this work when surface and volume closures are tangled up */
354
355                 if(surface_in->link)
356                         transform_multi_closure(surface_in->link->parent, NULL, false);
357                 if(volume_in->link)
358                         transform_multi_closure(volume_in->link->parent, NULL, true);
359
360                 finalized = true;
361         }
362         else if(do_simplify) {
363                 simplify_settings(scene);
364         }
365 }
366
367 void ShaderGraph::find_dependencies(ShaderNodeSet& dependencies, ShaderInput *input)
368 {
369         /* find all nodes that this input depends on directly and indirectly */
370         ShaderNode *node = (input->link)? input->link->parent: NULL;
371
372         if(node != NULL && dependencies.find(node) == dependencies.end()) {
373                 foreach(ShaderInput *in, node->inputs)
374                         find_dependencies(dependencies, in);
375
376                 dependencies.insert(node);
377         }
378 }
379
380 void ShaderGraph::clear_nodes()
381 {
382         foreach(ShaderNode *node, nodes) {
383                 delete node;
384         }
385         nodes.clear();
386 }
387
388 void ShaderGraph::copy_nodes(ShaderNodeSet& nodes, ShaderNodeMap& nnodemap)
389 {
390         /* copy a set of nodes, and the links between them. the assumption is
391          * made that all nodes that inputs are linked to are in the set too. */
392
393         /* copy nodes */
394         foreach(ShaderNode *node, nodes) {
395                 ShaderNode *nnode = node->clone();
396                 nnodemap[node] = nnode;
397
398                 /* create new inputs and outputs to recreate links and ensure
399                  * that we still point to valid SocketType if the NodeType
400                  * changed in cloning, as it does for OSL nodes */
401                 nnode->inputs.clear();
402                 nnode->outputs.clear();
403                 nnode->create_inputs_outputs(nnode->type);
404         }
405
406         /* recreate links */
407         foreach(ShaderNode *node, nodes) {
408                 foreach(ShaderInput *input, node->inputs) {
409                         if(input->link) {
410                                 /* find new input and output */
411                                 ShaderNode *nfrom = nnodemap[input->link->parent];
412                                 ShaderNode *nto = nnodemap[input->parent];
413                                 ShaderOutput *noutput = nfrom->output(input->link->name());
414                                 ShaderInput *ninput = nto->input(input->name());
415
416                                 /* connect */
417                                 connect(noutput, ninput);
418                         }
419                 }
420         }
421 }
422
423 /* Graph simplification */
424 /* ******************** */
425
/* Remove proxy nodes.
 *
 * These only exist temporarily when exporting groups, and we must remove them
 * early so that node->attributes() and default links do not see them.
 */
431 void ShaderGraph::remove_proxy_nodes()
432 {
433         vector<bool> removed(num_node_ids, false);
434         bool any_node_removed = false;
435
436         foreach(ShaderNode *node, nodes) {
437                 if(node->special_type == SHADER_SPECIAL_TYPE_PROXY) {
438                         ConvertNode *proxy = static_cast<ConvertNode*>(node);
439                         ShaderInput *input = proxy->inputs[0];
440                         ShaderOutput *output = proxy->outputs[0];
441
442                         /* bypass the proxy node */
443                         if(input->link) {
444                                 relink(proxy, output, input->link);
445                         }
446                         else {
447                                 /* Copy because disconnect modifies this list */
448                                 vector<ShaderInput*> links(output->links);
449
450                                 foreach(ShaderInput *to, links) {
451                                         /* remove any autoconvert nodes too if they lead to
452                                          * sockets with an automatically set default value */
453                                         ShaderNode *tonode = to->parent;
454
455                                         if(tonode->special_type == SHADER_SPECIAL_TYPE_AUTOCONVERT) {
456                                                 bool all_links_removed = true;
457                                                 vector<ShaderInput*> links = tonode->outputs[0]->links;
458
459                                                 foreach(ShaderInput *autoin, links) {
460                                                         if(autoin->flags() & SocketType::DEFAULT_LINK_MASK)
461                                                                 disconnect(autoin);
462                                                         else
463                                                                 all_links_removed = false;
464                                                 }
465
466                                                 if(all_links_removed)
467                                                         removed[tonode->id] = true;
468                                         }
469
470                                         disconnect(to);
471
472                                         /* transfer the default input value to the target socket */
473                                         tonode->copy_value(to->socket_type, *proxy, input->socket_type);
474                                 }
475                         }
476
477                         removed[proxy->id] = true;
478                         any_node_removed = true;
479                 }
480         }
481
482         /* remove nodes */
483         if(any_node_removed) {
484                 list<ShaderNode*> newnodes;
485
486                 foreach(ShaderNode *node, nodes) {
487                         if(!removed[node->id])
488                                 newnodes.push_back(node);
489                         else
490                                 delete node;
491                 }
492
493                 nodes = newnodes;
494         }
495 }
496
497 /* Constant folding.
498  *
499  * Try to constant fold some nodes, and pipe result directly to
500  * the input socket of connected nodes.
501  */
502 void ShaderGraph::constant_fold()
503 {
504         ShaderNodeSet done, scheduled;
505         queue<ShaderNode*> traverse_queue;
506
507         bool has_displacement = (output()->input("Displacement")->link != NULL);
508
509         /* Schedule nodes which doesn't have any dependencies. */
510         foreach(ShaderNode *node, nodes) {
511                 if(!check_node_inputs_has_links(node)) {
512                         traverse_queue.push(node);
513                         scheduled.insert(node);
514                 }
515         }
516
517         while(!traverse_queue.empty()) {
518                 ShaderNode *node = traverse_queue.front();
519                 traverse_queue.pop();
520                 done.insert(node);
521                 foreach(ShaderOutput *output, node->outputs) {
522                         if(output->links.size() == 0) {
523                                 continue;
524                         }
525                         /* Schedule node which was depending on the value,
526                          * when possible. Do it before disconnect.
527                          */
528                         foreach(ShaderInput *input, output->links) {
529                                 if(scheduled.find(input->parent) != scheduled.end()) {
530                                         /* Node might not be optimized yet but scheduled already
531                                          * by other dependencies. No need to re-schedule it.
532                                          */
533                                         continue;
534                                 }
535                                 /* Schedule node if its inputs are fully done. */
536                                 if(check_node_inputs_traversed(input->parent, done)) {
537                                         traverse_queue.push(input->parent);
538                                         scheduled.insert(input->parent);
539                                 }
540                         }
541                         /* Optimize current node. */
542                         ConstantFolder folder(this, node, output);
543                         node->constant_fold(folder);
544                 }
545         }
546
547         /* Folding might have removed all nodes connected to the displacement output
548          * even tho there is displacement to be applied, so add in a value node if
549          * that happens to ensure there is still a valid graph for displacement.
550          */
551         if(has_displacement && !output()->input("Displacement")->link) {
552                 ValueNode *value = (ValueNode*)add(new ValueNode());
553                 value->value = output()->displacement;
554
555                 connect(value->output("Value"), output()->input("Displacement"));
556         }
557 }
558
559 /* Simplification. */
560 void ShaderGraph::simplify_settings(Scene *scene)
561 {
562         foreach(ShaderNode *node, nodes) {
563                 node->simplify_settings(scene);
564         }
565 }
566
567 /* Deduplicate nodes with same settings. */
568 void ShaderGraph::deduplicate_nodes()
569 {
570         /* NOTES:
571          * - Deduplication happens for nodes which has same exact settings and same
572          *   exact input links configuration (either connected to same output or has
573          *   the same exact default value).
574          * - Deduplication happens in the bottom-top manner, so we know for fact that
575          *   all traversed nodes are either can not be deduplicated at all or were
576          *   already deduplicated.
577          */
578
579         ShaderNodeSet scheduled, done;
580         map<ustring, ShaderNodeSet> candidates;
581         queue<ShaderNode*> traverse_queue;
582         int num_deduplicated = 0;
583
584         /* Schedule nodes which doesn't have any dependencies. */
585         foreach(ShaderNode *node, nodes) {
586                 if(!check_node_inputs_has_links(node)) {
587                         traverse_queue.push(node);
588                         scheduled.insert(node);
589                 }
590         }
591
592         while(!traverse_queue.empty()) {
593                 ShaderNode *node = traverse_queue.front();
594                 traverse_queue.pop();
595                 done.insert(node);
596                 /* Schedule the nodes which were depending on the current node. */
597                 bool has_output_links = false;
598                 foreach(ShaderOutput *output, node->outputs) {
599                         foreach(ShaderInput *input, output->links) {
600                                 has_output_links = true;
601                                 if(scheduled.find(input->parent) != scheduled.end()) {
602                                         /* Node might not be optimized yet but scheduled already
603                                          * by other dependencies. No need to re-schedule it.
604                                          */
605                                         continue;
606                                 }
607                                 /* Schedule node if its inputs are fully done. */
608                                 if(check_node_inputs_traversed(input->parent, done)) {
609                                         traverse_queue.push(input->parent);
610                                         scheduled.insert(input->parent);
611                                 }
612                         }
613                 }
614                 /* Only need to care about nodes that are actually used */
615                 if(!has_output_links) {
616                         continue;
617                 }
618                 /* Try to merge this node with another one. */
619                 ShaderNode *merge_with = NULL;
620                 foreach(ShaderNode *other_node, candidates[node->type->name]) {
621                         if(node != other_node && node->equals(*other_node)) {
622                                 merge_with = other_node;
623                                 break;
624                         }
625                 }
626                 /* If found an equivalent, merge; otherwise keep node for later merges */
627                 if(merge_with != NULL) {
628                         for(int i = 0; i < node->outputs.size(); ++i) {
629                                 relink(node, node->outputs[i], merge_with->outputs[i]);
630                         }
631                         num_deduplicated++;
632                 }
633                 else {
634                         candidates[node->type->name].insert(node);
635                 }
636         }
637
638         if(num_deduplicated > 0) {
639                 VLOG(1) << "Deduplicated " << num_deduplicated << " nodes.";
640         }
641 }
642
643 /* Check whether volume output has meaningful nodes, otherwise
644  * disconnect the output.
645  */
646 void ShaderGraph::verify_volume_output()
647 {
648         /* Check whether we can optimize the whole volume graph out. */
649         ShaderInput *volume_in = output()->input("Volume");
650         if(volume_in->link == NULL) {
651                 return;
652         }
653         bool has_valid_volume = false;
654         ShaderNodeSet scheduled;
655         queue<ShaderNode*> traverse_queue;
656         /* Schedule volume output. */
657         traverse_queue.push(volume_in->link->parent);
658         scheduled.insert(volume_in->link->parent);
659         /* Traverse down the tree. */
660         while(!traverse_queue.empty()) {
661                 ShaderNode *node = traverse_queue.front();
662                 traverse_queue.pop();
663                 /* Node is fully valid for volume, can't optimize anything out. */
664                 if(node->has_volume_support()) {
665                         has_valid_volume = true;
666                         break;
667                 }
668                 foreach(ShaderInput *input, node->inputs) {
669                         if(input->link == NULL) {
670                                 continue;
671                         }
672                         if(scheduled.find(input->link->parent) != scheduled.end()) {
673                                 continue;
674                         }
675                         traverse_queue.push(input->link->parent);
676                         scheduled.insert(input->link->parent);
677                 }
678         }
679         if(!has_valid_volume) {
680                 VLOG(1) << "Disconnect meaningless volume output.";
681                 disconnect(volume_in->link);
682         }
683 }
684
685 void ShaderGraph::break_cycles(ShaderNode *node, vector<bool>& visited, vector<bool>& on_stack)
686 {
687         visited[node->id] = true;
688         on_stack[node->id] = true;
689
690         foreach(ShaderInput *input, node->inputs) {
691                 if(input->link) {
692                         ShaderNode *depnode = input->link->parent;
693
694                         if(on_stack[depnode->id]) {
695                                 /* break cycle */
696                                 disconnect(input);
697                                 fprintf(stderr, "Cycles shader graph: detected cycle in graph, connection removed.\n");
698                         }
699                         else if(!visited[depnode->id]) {
700                                 /* visit dependencies */
701                                 break_cycles(depnode, visited, on_stack);
702                         }
703                 }
704         }
705
706         on_stack[node->id] = false;
707 }
708
709 void ShaderGraph::clean(Scene *scene)
710 {
711         /* Graph simplification */
712
713         /* NOTE: Remove proxy nodes was already done. */
714         constant_fold();
715         simplify_settings(scene);
716         deduplicate_nodes();
717         verify_volume_output();
718
719         /* we do two things here: find cycles and break them, and remove unused
720          * nodes that don't feed into the output. how cycles are broken is
721          * undefined, they are invalid input, the important thing is to not crash */
722
723         vector<bool> visited(num_node_ids, false);
724         vector<bool> on_stack(num_node_ids, false);
725         
726         /* break cycles */
727         break_cycles(output(), visited, on_stack);
728
729         /* disconnect unused nodes */
730         foreach(ShaderNode *node, nodes) {
731                 if(!visited[node->id]) {
732                         foreach(ShaderInput *to, node->inputs) {
733                                 ShaderOutput *from = to->link;
734
735                                 if(from) {
736                                         to->link = NULL;
737                                         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
738                                 }
739                         }
740                 }
741         }
742
743         /* remove unused nodes */
744         list<ShaderNode*> newnodes;
745
746         foreach(ShaderNode *node, nodes) {
747                 if(visited[node->id])
748                         newnodes.push_back(node);
749                 else
750                         delete node;
751         }
752
753         nodes = newnodes;
754 }
755
756 void ShaderGraph::default_inputs(bool do_osl)
757 {
758         /* nodes can specify default texture coordinates, for now we give
759          * everything the position by default, except for the sky texture */
760
761         ShaderNode *geom = NULL;
762         ShaderNode *texco = NULL;
763
764         foreach(ShaderNode *node, nodes) {
765                 foreach(ShaderInput *input, node->inputs) {
766                         if(!input->link && (!(input->flags() & SocketType::OSL_INTERNAL) || do_osl)) {
767                                 if(input->flags() & SocketType::LINK_TEXTURE_GENERATED) {
768                                         if(!texco)
769                                                 texco = new TextureCoordinateNode();
770
771                                         connect(texco->output("Generated"), input);
772                                 }
773                                 else if(input->flags() & SocketType::LINK_TEXTURE_UV) {
774                                         if(!texco)
775                                                 texco = new TextureCoordinateNode();
776
777                                         connect(texco->output("UV"), input);
778                                 }
779                                 else if(input->flags() & SocketType::LINK_INCOMING) {
780                                         if(!geom)
781                                                 geom = new GeometryNode();
782
783                                         connect(geom->output("Incoming"), input);
784                                 }
785                                 else if(input->flags() & SocketType::LINK_NORMAL) {
786                                         if(!geom)
787                                                 geom = new GeometryNode();
788
789                                         connect(geom->output("Normal"), input);
790                                 }
791                                 else if(input->flags() & SocketType::LINK_POSITION) {
792                                         if(!geom)
793                                                 geom = new GeometryNode();
794
795                                         connect(geom->output("Position"), input);
796                                 }
797                                 else if(input->flags() & SocketType::LINK_TANGENT) {
798                                         if(!geom)
799                                                 geom = new GeometryNode();
800
801                                         connect(geom->output("Tangent"), input);
802                                 }
803                         }
804                 }
805         }
806
807         if(geom)
808                 add(geom);
809         if(texco)
810                 add(texco);
811 }
812
813 void ShaderGraph::refine_bump_nodes()
814 {
815         /* we transverse the node graph looking for bump nodes, when we find them,
816          * like in bump_from_displacement(), we copy the sub-graph defined from "bump"
817          * input to the inputs "center","dx" and "dy" What is in "bump" input is moved
818          * to "center" input. */
819
820         foreach(ShaderNode *node, nodes) {
821                 if(node->special_type == SHADER_SPECIAL_TYPE_BUMP && node->input("Height")->link) {
822                         ShaderInput *bump_input = node->input("Height");
823                         ShaderNodeSet nodes_bump;
824
825                         /* make 2 extra copies of the subgraph defined in Bump input */
826                         ShaderNodeMap nodes_dx;
827                         ShaderNodeMap nodes_dy;
828
829                         /* find dependencies for the given input */
830                         find_dependencies(nodes_bump, bump_input);
831
832                         copy_nodes(nodes_bump, nodes_dx);
833                         copy_nodes(nodes_bump, nodes_dy);
834         
835                         /* mark nodes to indicate they are use for bump computation, so
836                            that any texture coordinates are shifted by dx/dy when sampling */
837                         foreach(ShaderNode *node, nodes_bump)
838                                 node->bump = SHADER_BUMP_CENTER;
839                         foreach(NodePair& pair, nodes_dx)
840                                 pair.second->bump = SHADER_BUMP_DX;
841                         foreach(NodePair& pair, nodes_dy)
842                                 pair.second->bump = SHADER_BUMP_DY;
843
844                         ShaderOutput *out = bump_input->link;
845                         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name());
846                         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name());
847
848                         connect(out_dx, node->input("SampleX"));
849                         connect(out_dy, node->input("SampleY"));
850                         
851                         /* add generated nodes */
852                         foreach(NodePair& pair, nodes_dx)
853                                 add(pair.second);
854                         foreach(NodePair& pair, nodes_dy)
855                                 add(pair.second);
856                         
857                         /* connect what is connected is bump to samplecenter input*/
858                         connect(out , node->input("SampleCenter"));
859
860                         /* bump input is just for connectivity purpose for the graph input,
861                          * we re-connected this input to samplecenter, so lets disconnect it
862                          * from bump input */
863                         disconnect(bump_input);
864                 }
865         }
866 }
867
868 void ShaderGraph::bump_from_displacement(bool use_object_space)
869 {
870         /* generate bump mapping automatically from displacement. bump mapping is
871          * done using a 3-tap filter, computing the displacement at the center,
872          * and two other positions shifted by ray differentials.
873          *
874          * since the input to displacement is a node graph, we need to ensure that
875          * all texture coordinates use are shift by the ray differentials. for this
876          * reason we make 3 copies of the node subgraph defining the displacement,
877          * with each different geometry and texture coordinate nodes that generate
878          * different shifted coordinates.
879          *
880          * these 3 displacement values are then fed into the bump node, which will
881          * output the perturbed normal. */
882
883         ShaderInput *displacement_in = output()->input("Displacement");
884
885         if(!displacement_in->link)
886                 return;
887         
888         /* find dependencies for the given input */
889         ShaderNodeSet nodes_displace;
890         find_dependencies(nodes_displace, displacement_in);
891
892         /* copy nodes for 3 bump samples */
893         ShaderNodeMap nodes_center;
894         ShaderNodeMap nodes_dx;
895         ShaderNodeMap nodes_dy;
896
897         copy_nodes(nodes_displace, nodes_center);
898         copy_nodes(nodes_displace, nodes_dx);
899         copy_nodes(nodes_displace, nodes_dy);
900
901         /* mark nodes to indicate they are use for bump computation, so
902          * that any texture coordinates are shifted by dx/dy when sampling */
903         foreach(NodePair& pair, nodes_center)
904                 pair.second->bump = SHADER_BUMP_CENTER;
905         foreach(NodePair& pair, nodes_dx)
906                 pair.second->bump = SHADER_BUMP_DX;
907         foreach(NodePair& pair, nodes_dy)
908                 pair.second->bump = SHADER_BUMP_DY;
909
910         /* add set normal node and connect the bump normal ouput to the set normal
911          * output, so it can finally set the shader normal, note we are only doing
912          * this for bump from displacement, this will be the only bump allowed to
913          * overwrite the shader normal */
914         ShaderNode *set_normal = add(new SetNormalNode());
915         
916         /* add bump node and connect copied graphs to it */
917         BumpNode *bump = (BumpNode*)add(new BumpNode());
918         bump->use_object_space = use_object_space;
919
920         ShaderOutput *out = displacement_in->link;
921         ShaderOutput *out_center = nodes_center[out->parent]->output(out->name());
922         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name());
923         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name());
924
925         connect(out_center, bump->input("SampleCenter"));
926         connect(out_dx, bump->input("SampleX"));
927         connect(out_dy, bump->input("SampleY"));
928         
929         /* connect the bump out to the set normal in: */
930         connect(bump->output("Normal"), set_normal->input("Direction"));
931
932         /* connect to output node */
933         connect(set_normal->output("Normal"), output()->input("Normal"));
934
935         /* finally, add the copied nodes to the graph. we can't do this earlier
936          * because we would create dependency cycles in the above loop */
937         foreach(NodePair& pair, nodes_center)
938                 add(pair.second);
939         foreach(NodePair& pair, nodes_dx)
940                 add(pair.second);
941         foreach(NodePair& pair, nodes_dy)
942                 add(pair.second);
943 }
944
945 void ShaderGraph::transform_multi_closure(ShaderNode *node, ShaderOutput *weight_out, bool volume)
946 {
947         /* for SVM in multi closure mode, this transforms the shader mix/add part of
948          * the graph into nodes that feed weights into closure nodes. this is too
949          * avoid building a closure tree and then flattening it, and instead write it
950          * directly to an array */
951         
952         if(node->special_type == SHADER_SPECIAL_TYPE_COMBINE_CLOSURE) {
953                 ShaderInput *fin = node->input("Fac");
954                 ShaderInput *cl1in = node->input("Closure1");
955                 ShaderInput *cl2in = node->input("Closure2");
956                 ShaderOutput *weight1_out, *weight2_out;
957
958                 if(fin) {
959                         /* mix closure: add node to mix closure weights */
960                         MixClosureWeightNode *mix_node = new MixClosureWeightNode();
961                         add(mix_node);
962                         ShaderInput *fac_in = mix_node->input("Fac"); 
963                         ShaderInput *weight_in = mix_node->input("Weight"); 
964
965                         if(fin->link)
966                                 connect(fin->link, fac_in);
967                         else
968                                 mix_node->fac = node->get_float(fin->socket_type);
969
970                         if(weight_out)
971                                 connect(weight_out, weight_in);
972
973                         weight1_out = mix_node->output("Weight1");
974                         weight2_out = mix_node->output("Weight2");
975                 }
976                 else {
977                         /* add closure: just pass on any weights */
978                         weight1_out = weight_out;
979                         weight2_out = weight_out;
980                 }
981
982                 if(cl1in->link)
983                         transform_multi_closure(cl1in->link->parent, weight1_out, volume);
984                 if(cl2in->link)
985                         transform_multi_closure(cl2in->link->parent, weight2_out, volume);
986         }
987         else {
988                 ShaderInput *weight_in = node->input((volume)? "VolumeMixWeight": "SurfaceMixWeight");
989
990                 /* not a closure node? */
991                 if(!weight_in)
992                         return;
993
994                 /* already has a weight connected to it? add weights */
995                 float weight_value = node->get_float(weight_in->socket_type);
996                 if(weight_in->link || weight_value != 0.0f) {
997                         MathNode *math_node = new MathNode();
998                         add(math_node);
999
1000                         if(weight_in->link)
1001                                 connect(weight_in->link, math_node->input("Value1"));
1002                         else
1003                                 math_node->value1 = weight_value;
1004
1005                         if(weight_out)
1006                                 connect(weight_out, math_node->input("Value2"));
1007                         else
1008                                 math_node->value2 = 1.0f;
1009
1010                         weight_out = math_node->output("Value");
1011                         if(weight_in->link)
1012                                 disconnect(weight_in);
1013                 }
1014
1015                 /* connected to closure mix weight */
1016                 if(weight_out)
1017                         connect(weight_out, weight_in);
1018                 else
1019                         node->set(weight_in->socket_type, weight_value + 1.0f);
1020         }
1021 }
1022
1023 int ShaderGraph::get_num_closures()
1024 {
1025         int num_closures = 0;
1026         foreach(ShaderNode *node, nodes) {
1027                 ClosureType closure_type = node->get_closure_type();
1028                 if(closure_type == CLOSURE_NONE_ID) {
1029                         continue;
1030                 }
1031                 else if(CLOSURE_IS_BSSRDF(closure_type)) {
1032                         num_closures += 3;
1033                 }
1034                 else if(CLOSURE_IS_GLASS(closure_type)) {
1035                         num_closures += 2;
1036                 }
1037                 else if(CLOSURE_IS_BSDF_MULTISCATTER(closure_type)) {
1038                         num_closures += 2;
1039                 }
1040                 else {
1041                         ++num_closures;
1042                 }
1043         }
1044         return num_closures;
1045 }
1046
1047 void ShaderGraph::dump_graph(const char *filename)
1048 {
1049         FILE *fd = fopen(filename, "w");
1050
1051         if(fd == NULL) {
1052                 printf("Error opening file for dumping the graph: %s\n", filename);
1053                 return;
1054         }
1055
1056         fprintf(fd, "digraph shader_graph {\n");
1057         fprintf(fd, "ranksep=1.5\n");
1058         fprintf(fd, "rankdir=LR\n");
1059         fprintf(fd, "splines=false\n");
1060
1061         foreach(ShaderNode *node, nodes) {
1062                 fprintf(fd, "// NODE: %p\n", node);
1063                 fprintf(fd, "\"%p\" [shape=record,label=\"{", node);
1064                 if(node->inputs.size()) {
1065                         fprintf(fd, "{");
1066                         foreach(ShaderInput *socket, node->inputs) {
1067                                 if(socket != node->inputs[0]) {
1068                                         fprintf(fd, "|");
1069                                 }
1070                                 fprintf(fd, "<IN_%p>%s", socket, socket->name().c_str());
1071                         }
1072                         fprintf(fd, "}|");
1073                 }
1074                 fprintf(fd, "%s", node->name.c_str());
1075                 if(node->bump == SHADER_BUMP_CENTER) {
1076                         fprintf(fd, " (bump:center)");
1077                 }
1078                 else if(node->bump == SHADER_BUMP_DX) {
1079                         fprintf(fd, " (bump:dx)");
1080                 }
1081                 else if(node->bump == SHADER_BUMP_DY) {
1082                         fprintf(fd, " (bump:dy)");
1083                 }
1084                 if(node->outputs.size()) {
1085                         fprintf(fd, "|{");
1086                         foreach(ShaderOutput *socket, node->outputs) {
1087                                 if(socket != node->outputs[0]) {
1088                                         fprintf(fd, "|");
1089                                 }
1090                                 fprintf(fd, "<OUT_%p>%s", socket, socket->name().c_str());
1091                         }
1092                         fprintf(fd, "}");
1093                 }
1094                 fprintf(fd, "}\"]");
1095         }
1096
1097         foreach(ShaderNode *node, nodes) {
1098                 foreach(ShaderOutput *output, node->outputs) {
1099                         foreach(ShaderInput *input, output->links) {
1100                                 fprintf(fd,
1101                                         "// CONNECTION: OUT_%p->IN_%p (%s:%s)\n",
1102                                         output,
1103                                         input,
1104                                         output->name().c_str(), input->name().c_str());
1105                                 fprintf(fd,
1106                                         "\"%p\":\"OUT_%p\":e -> \"%p\":\"IN_%p\":w [label=\"\"]\n",
1107                                         output->parent,
1108                                         output,
1109                                         input->parent,
1110                                         input);
1111                         }
1112                 }
1113         }
1114
1115         fprintf(fd, "}\n");
1116         fclose(fd);
1117 }
1118
1119 CCL_NAMESPACE_END
1120