Code refactor: reduce special node types, use generic constant folding.
[blender.git] / intern / cycles / render / graph.cpp
1 /*
2  * Copyright 2011-2013 Blender Foundation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "attribute.h"
18 #include "graph.h"
19 #include "nodes.h"
20 #include "shader.h"
21
22 #include "util_algorithm.h"
23 #include "util_debug.h"
24 #include "util_foreach.h"
25 #include "util_queue.h"
26
27 CCL_NAMESPACE_BEGIN
28
29 namespace {
30
31 bool check_node_inputs_has_links(const ShaderNode *node)
32 {
33         foreach(const ShaderInput *in, node->inputs) {
34                 if(in->link) {
35                         return true;
36                 }
37         }
38         return false;
39 }
40
41 bool check_node_inputs_traversed(const ShaderNode *node,
42                                  const ShaderNodeSet& done)
43 {
44         foreach(const ShaderInput *in, node->inputs) {
45                 if(in->link) {
46                         if(done.find(in->link->parent) == done.end()) {
47                                 return false;
48                         }
49                 }
50         }
51         return true;
52 }
53
54 bool check_node_inputs_equals(const ShaderNode *node_a,
55                               const ShaderNode *node_b)
56 {
57         if(node_a->inputs.size() != node_b->inputs.size()) {
58                 /* Happens with BSDF closure nodes which are currently sharing the same
59                  * name for all the BSDF types, making it impossible to filter out
60                  * incompatible nodes.
61                  */
62                 return false;
63         }
64         for(int i = 0; i < node_a->inputs.size(); ++i) {
65                 ShaderInput *input_a = node_a->inputs[i],
66                             *input_b = node_b->inputs[i];
67                 if(input_a->link == NULL && input_b->link == NULL) {
68                         /* Unconnected inputs are expected to have the same value. */
69                         if(input_a->value != input_b->value) {
70                                 return false;
71                         }
72                 }
73                 else if(input_a->link != NULL && input_b->link != NULL) {
74                         /* Expect links are to come from the same exact socket. */
75                         if(input_a->link != input_b->link) {
76                                 return false;
77                         }
78                 }
79                 else {
80                         /* One socket has a link and another has not, inputs can't be
81                          * considered equal.
82                          */
83                         return false;
84                 }
85         }
86         return true;
87 }
88
89 }  /* namespace */
90
91 /* Input and Output */
92
93 ShaderInput::ShaderInput(ShaderNode *parent_, const char *name_, ShaderSocketType type_)
94 {
95         parent = parent_;
96         name = name_;
97         type = type_;
98         link = NULL;
99         value = make_float3(0.0f, 0.0f, 0.0f);
100         stack_offset = SVM_STACK_INVALID;
101         default_value = NONE;
102         usage = USE_ALL;
103 }
104
105 ShaderOutput::ShaderOutput(ShaderNode *parent_, const char *name_, ShaderSocketType type_)
106 {
107         parent = parent_;
108         name = name_;
109         type = type_;
110         stack_offset = SVM_STACK_INVALID;
111 }
112
113 /* Node */
114
115 ShaderNode::ShaderNode(const char *name_)
116 {
117         name = name_;
118         id = -1;
119         bump = SHADER_BUMP_NONE;
120         special_type = SHADER_SPECIAL_TYPE_NONE;
121 }
122
123 ShaderNode::~ShaderNode()
124 {
125         foreach(ShaderInput *socket, inputs)
126                 delete socket;
127
128         foreach(ShaderOutput *socket, outputs)
129                 delete socket;
130 }
131
132 ShaderInput *ShaderNode::input(const char *name)
133 {
134         foreach(ShaderInput *socket, inputs) {
135                 if(strcmp(socket->name, name) == 0)
136                         return socket;
137         }
138
139         return NULL;
140 }
141
142 ShaderOutput *ShaderNode::output(const char *name)
143 {
144         foreach(ShaderOutput *socket, outputs)
145                 if(strcmp(socket->name, name) == 0)
146                         return socket;
147
148         return NULL;
149 }
150
151 ShaderInput *ShaderNode::add_input(const char *name, ShaderSocketType type, float value, int usage)
152 {
153         ShaderInput *input = new ShaderInput(this, name, type);
154         input->value.x = value;
155         input->usage = usage;
156         inputs.push_back(input);
157         return input;
158 }
159
160 ShaderInput *ShaderNode::add_input(const char *name, ShaderSocketType type, float3 value, int usage)
161 {
162         ShaderInput *input = new ShaderInput(this, name, type);
163         input->value = value;
164         input->usage = usage;
165         inputs.push_back(input);
166         return input;
167 }
168
169 ShaderInput *ShaderNode::add_input(const char *name, ShaderSocketType type, ShaderInput::DefaultValue value, int usage)
170 {
171         ShaderInput *input = add_input(name, type);
172         input->default_value = value;
173         input->usage = usage;
174         return input;
175 }
176
177 ShaderOutput *ShaderNode::add_output(const char *name, ShaderSocketType type)
178 {
179         ShaderOutput *output = new ShaderOutput(this, name, type);
180         outputs.push_back(output);
181         return output;
182 }
183
184 void ShaderNode::attributes(Shader *shader, AttributeRequestSet *attributes)
185 {
186         foreach(ShaderInput *input, inputs) {
187                 if(!input->link) {
188                         if(input->default_value == ShaderInput::TEXTURE_GENERATED) {
189                                 if(shader->has_surface)
190                                         attributes->add(ATTR_STD_GENERATED);
191                                 if(shader->has_volume)
192                                         attributes->add(ATTR_STD_GENERATED_TRANSFORM);
193                         }
194                         else if(input->default_value == ShaderInput::TEXTURE_UV) {
195                                 if(shader->has_surface)
196                                         attributes->add(ATTR_STD_UV);
197                         }
198                 }
199         }
200 }
201
202 /* Graph */
203
204 ShaderGraph::ShaderGraph()
205 {
206         finalized = false;
207         num_node_ids = 0;
208         add(new OutputNode());
209 }
210
211 ShaderGraph::~ShaderGraph()
212 {
213         clear_nodes();
214 }
215
216 ShaderNode *ShaderGraph::add(ShaderNode *node)
217 {
218         assert(!finalized);
219         node->id = num_node_ids++;
220         nodes.push_back(node);
221         return node;
222 }
223
224 OutputNode *ShaderGraph::output()
225 {
226         return (OutputNode*)nodes.front();
227 }
228
229 ShaderGraph *ShaderGraph::copy()
230 {
231         ShaderGraph *newgraph = new ShaderGraph();
232
233         /* copy nodes */
234         ShaderNodeSet nodes_all;
235         foreach(ShaderNode *node, nodes)
236                 nodes_all.insert(node);
237
238         ShaderNodeMap nodes_copy;
239         copy_nodes(nodes_all, nodes_copy);
240
241         /* add nodes (in same order, so output is still first) */
242         newgraph->clear_nodes();
243         foreach(ShaderNode *node, nodes)
244                 newgraph->add(nodes_copy[node]);
245
246         return newgraph;
247 }
248
249 void ShaderGraph::connect(ShaderOutput *from, ShaderInput *to)
250 {
251         assert(!finalized);
252         assert(from && to);
253
254         if(to->link) {
255                 fprintf(stderr, "Cycles shader graph connect: input already connected.\n");
256                 return;
257         }
258
259         if(from->type != to->type) {
260                 /* for closures we can't do automatic conversion */
261                 if(from->type == SHADER_SOCKET_CLOSURE || to->type == SHADER_SOCKET_CLOSURE) {
262                         fprintf(stderr, "Cycles shader graph connect: can only connect closure to closure "
263                                 "(%s.%s to %s.%s).\n",
264                                 from->parent->name.c_str(), from->name,
265                                 to->parent->name.c_str(), to->name);
266                         return;
267                 }
268
269                 /* add automatic conversion node in case of type mismatch */
270                 ShaderNode *convert = add(new ConvertNode(from->type, to->type, true));
271
272                 connect(from, convert->inputs[0]);
273                 connect(convert->outputs[0], to);
274         }
275         else {
276                 /* types match, just connect */
277                 to->link = from;
278                 from->links.push_back(to);
279         }
280 }
281
282 void ShaderGraph::disconnect(ShaderInput *to)
283 {
284         assert(!finalized);
285         assert(to->link);
286
287         ShaderOutput *from = to->link;
288
289         to->link = NULL;
290         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
291 }
292
293 void ShaderGraph::relink(ShaderNode *node, ShaderOutput *from, ShaderOutput *to)
294 {
295         /* Copy because disconnect modifies this list */
296         vector<ShaderInput*> outputs = from->links;
297
298         /* Bypass node by moving all links from "from" to "to" */
299         foreach(ShaderInput *sock, node->inputs) {
300                 if(sock->link)
301                         disconnect(sock);
302         }
303
304         foreach(ShaderInput *sock, outputs) {
305                 disconnect(sock);
306                 if(to)
307                         connect(to, sock);
308         }
309 }
310
311 void ShaderGraph::finalize(Scene *scene,
312                            bool do_bump,
313                            bool do_osl,
314                            bool do_simplify)
315 {
316         /* before compiling, the shader graph may undergo a number of modifications.
317          * currently we set default geometry shader inputs, and create automatic bump
318          * from displacement. a graph can be finalized only once, and should not be
319          * modified afterwards. */
320
321         if(!finalized) {
322                 clean(scene);
323                 default_inputs(do_osl);
324                 refine_bump_nodes();
325
326                 if(do_bump)
327                         bump_from_displacement();
328
329                 ShaderInput *surface_in = output()->input("Surface");
330                 ShaderInput *volume_in = output()->input("Volume");
331
332                 /* todo: make this work when surface and volume closures are tangled up */
333
334                 if(surface_in->link)
335                         transform_multi_closure(surface_in->link->parent, NULL, false);
336                 if(volume_in->link)
337                         transform_multi_closure(volume_in->link->parent, NULL, true);
338
339                 finalized = true;
340         }
341         else if(do_simplify) {
342                 simplify_settings(scene);
343         }
344 }
345
346 void ShaderGraph::find_dependencies(ShaderNodeSet& dependencies, ShaderInput *input)
347 {
348         /* find all nodes that this input depends on directly and indirectly */
349         ShaderNode *node = (input->link)? input->link->parent: NULL;
350
351         if(node != NULL && dependencies.find(node) == dependencies.end()) {
352                 foreach(ShaderInput *in, node->inputs)
353                         find_dependencies(dependencies, in);
354
355                 dependencies.insert(node);
356         }
357 }
358
359 void ShaderGraph::clear_nodes()
360 {
361         foreach(ShaderNode *node, nodes) {
362                 delete node;
363         }
364         nodes.clear();
365 }
366
367 void ShaderGraph::copy_nodes(ShaderNodeSet& nodes, ShaderNodeMap& nnodemap)
368 {
369         /* copy a set of nodes, and the links between them. the assumption is
370          * made that all nodes that inputs are linked to are in the set too. */
371
372         /* copy nodes */
373         foreach(ShaderNode *node, nodes) {
374                 ShaderNode *nnode = node->clone();
375                 nnodemap[node] = nnode;
376
377                 nnode->inputs.clear();
378                 nnode->outputs.clear();
379
380                 foreach(ShaderInput *input, node->inputs) {
381                         ShaderInput *ninput = new ShaderInput(*input);
382                         nnode->inputs.push_back(ninput);
383
384                         ninput->parent = nnode;
385                         ninput->link = NULL;
386                 }
387
388                 foreach(ShaderOutput *output, node->outputs) {
389                         ShaderOutput *noutput = new ShaderOutput(*output);
390                         nnode->outputs.push_back(noutput);
391
392                         noutput->parent = nnode;
393                         noutput->links.clear();
394                 }
395         }
396
397         /* recreate links */
398         foreach(ShaderNode *node, nodes) {
399                 foreach(ShaderInput *input, node->inputs) {
400                         if(input->link) {
401                                 /* find new input and output */
402                                 ShaderNode *nfrom = nnodemap[input->link->parent];
403                                 ShaderNode *nto = nnodemap[input->parent];
404                                 ShaderOutput *noutput = nfrom->output(input->link->name);
405                                 ShaderInput *ninput = nto->input(input->name);
406
407                                 /* connect */
408                                 connect(noutput, ninput);
409                         }
410                 }
411         }
412 }
413
414 /* Graph simplification */
415 /* ******************** */
416
417 /* Step 1: Remove proxy nodes.
418  * These only exists temporarily when exporting groups, and we must remove them
419  * early so that node->attributes() and default links do not see them.
420  */
421 void ShaderGraph::remove_proxy_nodes()
422 {
423         vector<bool> removed(num_node_ids, false);
424         bool any_node_removed = false;
425
426         foreach(ShaderNode *node, nodes) {
427                 if(node->special_type == SHADER_SPECIAL_TYPE_PROXY) {
428                         ConvertNode *proxy = static_cast<ConvertNode*>(node);
429                         ShaderInput *input = proxy->inputs[0];
430                         ShaderOutput *output = proxy->outputs[0];
431
432                         /* bypass the proxy node */
433                         if(input->link) {
434                                 relink(proxy, output, input->link);
435                         }
436                         else {
437                                 /* Copy because disconnect modifies this list */
438                                 vector<ShaderInput*> links(output->links);
439
440                                 foreach(ShaderInput *to, links) {
441                                         /* remove any autoconvert nodes too if they lead to
442                                          * sockets with an automatically set default value */
443                                         ShaderNode *tonode = to->parent;
444
445                                         if(tonode->special_type == SHADER_SPECIAL_TYPE_AUTOCONVERT) {
446                                                 bool all_links_removed = true;
447                                                 vector<ShaderInput*> links = tonode->outputs[0]->links;
448
449                                                 foreach(ShaderInput *autoin, links) {
450                                                         if(autoin->default_value == ShaderInput::NONE)
451                                                                 all_links_removed = false;
452                                                         else
453                                                                 disconnect(autoin);
454                                                 }
455
456                                                 if(all_links_removed)
457                                                         removed[tonode->id] = true;
458                                         }
459
460                                         disconnect(to);
461
462                                         /* transfer the default input value to the target socket */
463                                         to->set(input->value);
464                                         to->set(input->value_string);
465                                 }
466                         }
467
468                         removed[proxy->id] = true;
469                         any_node_removed = true;
470                 }
471         }
472
473         /* remove nodes */
474         if(any_node_removed) {
475                 list<ShaderNode*> newnodes;
476
477                 foreach(ShaderNode *node, nodes) {
478                         if(!removed[node->id])
479                                 newnodes.push_back(node);
480                         else
481                                 delete node;
482                 }
483
484                 nodes = newnodes;
485         }
486 }
487
488 /* Step 2: Constant folding.
489  * Try to constant fold some nodes, and pipe result directly to
490  * the input socket of connected nodes.
491  */
492 void ShaderGraph::constant_fold()
493 {
494         ShaderNodeSet done, scheduled;
495         queue<ShaderNode*> traverse_queue;
496
497         /* Schedule nodes which doesn't have any dependencies. */
498         foreach(ShaderNode *node, nodes) {
499                 if(!check_node_inputs_has_links(node)) {
500                         traverse_queue.push(node);
501                         scheduled.insert(node);
502                 }
503         }
504
505         while(!traverse_queue.empty()) {
506                 ShaderNode *node = traverse_queue.front();
507                 traverse_queue.pop();
508                 done.insert(node);
509                 foreach(ShaderOutput *output, node->outputs) {
510                         if (output->links.size() == 0) {
511                                 continue;
512                         }
513                         /* Schedule node which was depending on the value,
514                          * when possible. Do it before disconnect.
515                          */
516                         foreach(ShaderInput *input, output->links) {
517                                 if(scheduled.find(input->parent) != scheduled.end()) {
518                                         /* Node might not be optimized yet but scheduled already
519                                          * by other dependencies. No need to re-schedule it.
520                                          */
521                                         continue;
522                                 }
523                                 /* Schedule node if its inputs are fully done. */
524                                 if(check_node_inputs_traversed(input->parent, done)) {
525                                         traverse_queue.push(input->parent);
526                                         scheduled.insert(input->parent);
527                                 }
528                         }
529                         /* Optimize current node. */
530                         float3 optimized_value = make_float3(0.0f, 0.0f, 0.0f);
531                         if(node->constant_fold(this, output, &optimized_value)) {
532                                 /* Apply optimized value to connected sockets. */
533                                 vector<ShaderInput*> links(output->links);
534                                 foreach(ShaderInput *input, links) {
535                                         /* Assign value and disconnect the optimizedinput. */
536                                         input->value = optimized_value;
537                                         disconnect(input);
538                                 }
539                         }
540                 }
541         }
542 }
543
544 /* Step 3: Simplification. */
545 void ShaderGraph::simplify_settings(Scene *scene)
546 {
547         foreach(ShaderNode *node, nodes) {
548                 node->simplify_settings(scene);
549         }
550 }
551
552 /* Step 4: Deduplicate nodes with same settings. */
553 void ShaderGraph::deduplicate_nodes()
554 {
555         /* NOTES:
556          * - Deduplication happens for nodes which has same exact settings and same
557          *   exact input links configuration (either connected to same output or has
558          *   the same exact default value).
559          * - Deduplication happens in the bottom-top manner, so we know for fact that
560          *   all traversed nodes are either can not be deduplicated at all or were
561          *   already deduplicated.
562          */
563
564         ShaderNodeSet scheduled;
565         map<ustring, ShaderNodeSet> done;
566         queue<ShaderNode*> traverse_queue;
567
568         /* Schedule nodes which doesn't have any dependencies. */
569         foreach(ShaderNode *node, nodes) {
570                 if(!check_node_inputs_has_links(node)) {
571                         traverse_queue.push(node);
572                         scheduled.insert(node);
573                 }
574         }
575
576         while(!traverse_queue.empty()) {
577                 ShaderNode *node = traverse_queue.front();
578                 traverse_queue.pop();
579                 done[node->name].insert(node);
580                 /* Schedule the nodes which were depending on the current node. */
581                 foreach(ShaderOutput *output, node->outputs) {
582                         foreach(ShaderInput *input, output->links) {
583                                 if(scheduled.find(input->parent) != scheduled.end()) {
584                                         /* Node might not be optimized yet but scheduled already
585                                          * by other dependencies. No need to re-schedule it.
586                                          */
587                                         continue;
588                                 }
589                                 /* Schedule node if its inputs are fully done. */
590                                 if(check_node_inputs_traversed(input->parent, done[input->parent->name])) {
591                                         traverse_queue.push(input->parent);
592                                         scheduled.insert(input->parent);
593                                 }
594                         }
595                 }
596                 /* Try to merge this node with another one. */
597                 foreach(ShaderNode *other_node, done[node->name]) {
598                         if(node == other_node) {
599                                 /* Don't merge with self. */
600                                 continue;
601                         }
602                         if(node->name != other_node->name) {
603                                 /* Can only de-duplicate nodes of the same type. */
604                                 continue;
605                         }
606                         if(!check_node_inputs_equals(node, other_node)) {
607                                 /* Node inputs are different, can't merge them, */
608                                 continue;
609                         }
610                         if(!node->equals(other_node)) {
611                                 /* Node settings are different. */
612                                 continue;
613                         }
614                         /* TODO(sergey): Consider making it an utility function. */
615                         for(int i = 0; i < node->outputs.size(); ++i) {
616                                 relink(node, node->outputs[i], other_node->outputs[i]);
617                         }
618                         break;
619                 }
620         }
621 }
622
623 void ShaderGraph::break_cycles(ShaderNode *node, vector<bool>& visited, vector<bool>& on_stack)
624 {
625         visited[node->id] = true;
626         on_stack[node->id] = true;
627
628         foreach(ShaderInput *input, node->inputs) {
629                 if(input->link) {
630                         ShaderNode *depnode = input->link->parent;
631
632                         if(on_stack[depnode->id]) {
633                                 /* break cycle */
634                                 disconnect(input);
635                                 fprintf(stderr, "Cycles shader graph: detected cycle in graph, connection removed.\n");
636                         }
637                         else if(!visited[depnode->id]) {
638                                 /* visit dependencies */
639                                 break_cycles(depnode, visited, on_stack);
640                         }
641                 }
642         }
643
644         on_stack[node->id] = false;
645 }
646
647 void ShaderGraph::clean(Scene *scene)
648 {
649         /* Graph simplification */
650
651         /* 1: Remove proxy nodes was already done. */
652
653         /* 2: Constant folding. */
654         constant_fold();
655
656         /* 3: Simplification. */
657         simplify_settings(scene);
658
659         /* 4: De-duplication. */
660         deduplicate_nodes();
661
662         /* we do two things here: find cycles and break them, and remove unused
663          * nodes that don't feed into the output. how cycles are broken is
664          * undefined, they are invalid input, the important thing is to not crash */
665
666         vector<bool> visited(num_node_ids, false);
667         vector<bool> on_stack(num_node_ids, false);
668         
669         /* break cycles */
670         break_cycles(output(), visited, on_stack);
671
672         /* disconnect unused nodes */
673         foreach(ShaderNode *node, nodes) {
674                 if(!visited[node->id]) {
675                         foreach(ShaderInput *to, node->inputs) {
676                                 ShaderOutput *from = to->link;
677
678                                 if(from) {
679                                         to->link = NULL;
680                                         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
681                                 }
682                         }
683                 }
684         }
685
686         /* remove unused nodes */
687         list<ShaderNode*> newnodes;
688
689         foreach(ShaderNode *node, nodes) {
690                 if(visited[node->id])
691                         newnodes.push_back(node);
692                 else
693                         delete node;
694         }
695
696         nodes = newnodes;
697 }
698
699 void ShaderGraph::default_inputs(bool do_osl)
700 {
701         /* nodes can specify default texture coordinates, for now we give
702          * everything the position by default, except for the sky texture */
703
704         ShaderNode *geom = NULL;
705         ShaderNode *texco = NULL;
706
707         foreach(ShaderNode *node, nodes) {
708                 foreach(ShaderInput *input, node->inputs) {
709                         if(!input->link && ((input->usage & ShaderInput::USE_SVM) || do_osl)) {
710                                 if(input->default_value == ShaderInput::TEXTURE_GENERATED) {
711                                         if(!texco)
712                                                 texco = new TextureCoordinateNode();
713
714                                         connect(texco->output("Generated"), input);
715                                 }
716                                 else if(input->default_value == ShaderInput::TEXTURE_UV) {
717                                         if(!texco)
718                                                 texco = new TextureCoordinateNode();
719
720                                         connect(texco->output("UV"), input);
721                                 }
722                                 else if(input->default_value == ShaderInput::INCOMING) {
723                                         if(!geom)
724                                                 geom = new GeometryNode();
725
726                                         connect(geom->output("Incoming"), input);
727                                 }
728                                 else if(input->default_value == ShaderInput::NORMAL) {
729                                         if(!geom)
730                                                 geom = new GeometryNode();
731
732                                         connect(geom->output("Normal"), input);
733                                 }
734                                 else if(input->default_value == ShaderInput::POSITION) {
735                                         if(!geom)
736                                                 geom = new GeometryNode();
737
738                                         connect(geom->output("Position"), input);
739                                 }
740                                 else if(input->default_value == ShaderInput::TANGENT) {
741                                         if(!geom)
742                                                 geom = new GeometryNode();
743
744                                         connect(geom->output("Tangent"), input);
745                                 }
746                         }
747                 }
748         }
749
750         if(geom)
751                 add(geom);
752         if(texco)
753                 add(texco);
754 }
755
756 void ShaderGraph::refine_bump_nodes()
757 {
758         /* we transverse the node graph looking for bump nodes, when we find them,
759          * like in bump_from_displacement(), we copy the sub-graph defined from "bump"
760          * input to the inputs "center","dx" and "dy" What is in "bump" input is moved
761          * to "center" input. */
762
763         foreach(ShaderNode *node, nodes) {
764                 if(node->special_type == SHADER_SPECIAL_TYPE_BUMP && node->input("Height")->link) {
765                         ShaderInput *bump_input = node->input("Height");
766                         ShaderNodeSet nodes_bump;
767
768                         /* make 2 extra copies of the subgraph defined in Bump input */
769                         ShaderNodeMap nodes_dx;
770                         ShaderNodeMap nodes_dy;
771
772                         /* find dependencies for the given input */
773                         find_dependencies(nodes_bump, bump_input);
774
775                         copy_nodes(nodes_bump, nodes_dx);
776                         copy_nodes(nodes_bump, nodes_dy);
777         
778                         /* mark nodes to indicate they are use for bump computation, so
779                            that any texture coordinates are shifted by dx/dy when sampling */
780                         foreach(ShaderNode *node, nodes_bump)
781                                 node->bump = SHADER_BUMP_CENTER;
782                         foreach(NodePair& pair, nodes_dx)
783                                 pair.second->bump = SHADER_BUMP_DX;
784                         foreach(NodePair& pair, nodes_dy)
785                                 pair.second->bump = SHADER_BUMP_DY;
786
787                         ShaderOutput *out = bump_input->link;
788                         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name);
789                         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name);
790
791                         connect(out_dx, node->input("SampleX"));
792                         connect(out_dy, node->input("SampleY"));
793                         
794                         /* add generated nodes */
795                         foreach(NodePair& pair, nodes_dx)
796                                 add(pair.second);
797                         foreach(NodePair& pair, nodes_dy)
798                                 add(pair.second);
799                         
800                         /* connect what is connected is bump to samplecenter input*/
801                         connect(out , node->input("SampleCenter"));
802
803                         /* bump input is just for connectivity purpose for the graph input,
804                          * we re-connected this input to samplecenter, so lets disconnect it
805                          * from bump input */
806                         disconnect(bump_input);
807                 }
808         }
809 }
810
811 void ShaderGraph::bump_from_displacement()
812 {
813         /* generate bump mapping automatically from displacement. bump mapping is
814          * done using a 3-tap filter, computing the displacement at the center,
815          * and two other positions shifted by ray differentials.
816          *
817          * since the input to displacement is a node graph, we need to ensure that
818          * all texture coordinates use are shift by the ray differentials. for this
819          * reason we make 3 copies of the node subgraph defining the displacement,
820          * with each different geometry and texture coordinate nodes that generate
821          * different shifted coordinates.
822          *
823          * these 3 displacement values are then fed into the bump node, which will
824          * output the perturbed normal. */
825
826         ShaderInput *displacement_in = output()->input("Displacement");
827
828         if(!displacement_in->link)
829                 return;
830         
831         /* find dependencies for the given input */
832         ShaderNodeSet nodes_displace;
833         find_dependencies(nodes_displace, displacement_in);
834
835         /* copy nodes for 3 bump samples */
836         ShaderNodeMap nodes_center;
837         ShaderNodeMap nodes_dx;
838         ShaderNodeMap nodes_dy;
839
840         copy_nodes(nodes_displace, nodes_center);
841         copy_nodes(nodes_displace, nodes_dx);
842         copy_nodes(nodes_displace, nodes_dy);
843
844         /* mark nodes to indicate they are use for bump computation, so
845          * that any texture coordinates are shifted by dx/dy when sampling */
846         foreach(NodePair& pair, nodes_center)
847                 pair.second->bump = SHADER_BUMP_CENTER;
848         foreach(NodePair& pair, nodes_dx)
849                 pair.second->bump = SHADER_BUMP_DX;
850         foreach(NodePair& pair, nodes_dy)
851                 pair.second->bump = SHADER_BUMP_DY;
852
853         /* add set normal node and connect the bump normal ouput to the set normal
854          * output, so it can finally set the shader normal, note we are only doing
855          * this for bump from displacement, this will be the only bump allowed to
856          * overwrite the shader normal */
857         ShaderNode *set_normal = add(new SetNormalNode());
858         
859         /* add bump node and connect copied graphs to it */
860         ShaderNode *bump = add(new BumpNode());
861
862         ShaderOutput *out = displacement_in->link;
863         ShaderOutput *out_center = nodes_center[out->parent]->output(out->name);
864         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name);
865         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name);
866
867         connect(out_center, bump->input("SampleCenter"));
868         connect(out_dx, bump->input("SampleX"));
869         connect(out_dy, bump->input("SampleY"));
870         
871         /* connect the bump out to the set normal in: */
872         connect(bump->output("Normal"), set_normal->input("Direction"));
873
874         /* connect bump output to normal input nodes that aren't set yet. actually
875          * this will only set the normal input to the geometry node that we created
876          * and connected to all other normal inputs already. */
877         foreach(ShaderNode *node, nodes) {
878                 /* Don't connect normal to the bump node we're coming from,
879                  * otherwise it'll be a cycle in graph.
880                  */
881                 if(node == bump) {
882                         continue;
883                 }
884                 foreach(ShaderInput *input, node->inputs) {
885                         if(!input->link && input->default_value == ShaderInput::NORMAL)
886                                 connect(set_normal->output("Normal"), input);
887                 }
888         }
889
890         /* for displacement bump, clear the normal input in case the above loop
891          * connected the setnormal out to the bump normalin */
892         ShaderInput *bump_normal_in = bump->input("Normal");
893         if(bump_normal_in)
894                 bump_normal_in->link = NULL;
895
896         /* finally, add the copied nodes to the graph. we can't do this earlier
897          * because we would create dependency cycles in the above loop */
898         foreach(NodePair& pair, nodes_center)
899                 add(pair.second);
900         foreach(NodePair& pair, nodes_dx)
901                 add(pair.second);
902         foreach(NodePair& pair, nodes_dy)
903                 add(pair.second);
904 }
905
906 void ShaderGraph::transform_multi_closure(ShaderNode *node, ShaderOutput *weight_out, bool volume)
907 {
908         /* for SVM in multi closure mode, this transforms the shader mix/add part of
909          * the graph into nodes that feed weights into closure nodes. this is too
910          * avoid building a closure tree and then flattening it, and instead write it
911          * directly to an array */
912         
913         if(node->name == ustring("mix_closure") || node->name == ustring("add_closure")) {
914                 ShaderInput *fin = node->input("Fac");
915                 ShaderInput *cl1in = node->input("Closure1");
916                 ShaderInput *cl2in = node->input("Closure2");
917                 ShaderOutput *weight1_out, *weight2_out;
918
919                 if(fin) {
920                         /* mix closure: add node to mix closure weights */
921                         ShaderNode *mix_node = add(new MixClosureWeightNode());
922                         ShaderInput *fac_in = mix_node->input("Fac"); 
923                         ShaderInput *weight_in = mix_node->input("Weight"); 
924
925                         if(fin->link)
926                                 connect(fin->link, fac_in);
927                         else
928                                 fac_in->value = fin->value;
929
930                         if(weight_out)
931                                 connect(weight_out, weight_in);
932
933                         weight1_out = mix_node->output("Weight1");
934                         weight2_out = mix_node->output("Weight2");
935                 }
936                 else {
937                         /* add closure: just pass on any weights */
938                         weight1_out = weight_out;
939                         weight2_out = weight_out;
940                 }
941
942                 if(cl1in->link)
943                         transform_multi_closure(cl1in->link->parent, weight1_out, volume);
944                 if(cl2in->link)
945                         transform_multi_closure(cl2in->link->parent, weight2_out, volume);
946         }
947         else {
948                 ShaderInput *weight_in = node->input((volume)? "VolumeMixWeight": "SurfaceMixWeight");
949
950                 /* not a closure node? */
951                 if(!weight_in)
952                         return;
953
954                 /* already has a weight connected to it? add weights */
955                 if(weight_in->link || weight_in->value.x != 0.0f) {
956                         ShaderNode *math_node = add(new MathNode());
957                         ShaderInput *value1_in = math_node->input("Value1");
958                         ShaderInput *value2_in = math_node->input("Value2");
959
960                         if(weight_in->link)
961                                 connect(weight_in->link, value1_in);
962                         else
963                                 value1_in->value = weight_in->value;
964
965                         if(weight_out)
966                                 connect(weight_out, value2_in);
967                         else
968                                 value2_in->value.x = 1.0f;
969
970                         weight_out = math_node->output("Value");
971                         if(weight_in->link)
972                                 disconnect(weight_in);
973                 }
974
975                 /* connected to closure mix weight */
976                 if(weight_out)
977                         connect(weight_out, weight_in);
978                 else
979                         weight_in->value.x += 1.0f;
980         }
981 }
982
983 int ShaderGraph::get_num_closures()
984 {
985         int num_closures = 0;
986         foreach(ShaderNode *node, nodes) {
987                 if(node->special_type == SHADER_SPECIAL_TYPE_CLOSURE) {
988                         BsdfNode *bsdf_node = static_cast<BsdfNode*>(node);
989                         /* TODO(sergey): Make it more generic approach, maybe some utility
990                          * macros like CLOSURE_IS_FOO()?
991                          */
992                         if(CLOSURE_IS_BSSRDF(bsdf_node->closure))
993                                 num_closures = num_closures + 3;
994                         else if(CLOSURE_IS_GLASS(bsdf_node->closure))
995                                 num_closures = num_closures + 2;
996                         else
997                                 num_closures = num_closures + 1;
998                 }
999         }
1000         return num_closures;
1001 }
1002
1003 void ShaderGraph::dump_graph(const char *filename)
1004 {
1005         FILE *fd = fopen(filename, "w");
1006
1007         if(fd == NULL) {
1008                 printf("Error opening file for dumping the graph: %s\n", filename);
1009                 return;
1010         }
1011
1012         fprintf(fd, "digraph shader_graph {\n");
1013         fprintf(fd, "ranksep=1.5\n");
1014         fprintf(fd, "rankdir=LR\n");
1015         fprintf(fd, "splines=false\n");
1016
1017         foreach(ShaderNode *node, nodes) {
1018                 fprintf(fd, "// NODE: %p\n", node);
1019                 fprintf(fd, "\"%p\" [shape=record,label=\"{", node);
1020                 if(node->inputs.size()) {
1021                         fprintf(fd, "{");
1022                         foreach(ShaderInput *socket, node->inputs) {
1023                                 if(socket != node->inputs[0]) {
1024                                         fprintf(fd, "|");
1025                                 }
1026                                 fprintf(fd, "<IN_%p>%s", socket, socket->name);
1027                         }
1028                         fprintf(fd, "}|");
1029                 }
1030                 fprintf(fd, "%s", node->name.c_str());
1031                 if(node->bump == SHADER_BUMP_CENTER) {
1032                         fprintf(fd, " (bump:center)");
1033                 }
1034                 else if(node->bump == SHADER_BUMP_DX) {
1035                         fprintf(fd, " (bump:dx)");
1036                 }
1037                 else if(node->bump == SHADER_BUMP_DY) {
1038                         fprintf(fd, " (bump:dy)");
1039                 }
1040                 if(node->outputs.size()) {
1041                         fprintf(fd, "|{");
1042                         foreach(ShaderOutput *socket, node->outputs) {
1043                                 if(socket != node->outputs[0]) {
1044                                         fprintf(fd, "|");
1045                                 }
1046                                 fprintf(fd, "<OUT_%p>%s", socket, socket->name);
1047                         }
1048                         fprintf(fd, "}");
1049                 }
1050                 fprintf(fd, "}\"]");
1051         }
1052
1053         foreach(ShaderNode *node, nodes) {
1054                 foreach(ShaderOutput *output, node->outputs) {
1055                         foreach(ShaderInput *input, output->links) {
1056                                 fprintf(fd,
1057                                         "// CONNECTION: OUT_%p->IN_%p (%s:%s)\n",
1058                                         output,
1059                                         input,
1060                                         output->name, input->name);
1061                                 fprintf(fd,
1062                                         "\"%p\":\"OUT_%p\":e -> \"%p\":\"IN_%p\":w [label=\"\"]\n",
1063                                         output->parent,
1064                                         output,
1065                                         input->parent,
1066                                         input);
1067                         }
1068                 }
1069         }
1070
1071         fprintf(fd, "}\n");
1072         fclose(fd);
1073 }
1074
1075 CCL_NAMESPACE_END
1076