Cycles: Remove unneeded include statements
[blender-staging.git] / intern / cycles / render / graph.cpp
1 /*
2  * Copyright 2011-2016 Blender Foundation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "render/attribute.h"
18 #include "render/graph.h"
19 #include "render/nodes.h"
20 #include "render/scene.h"
21 #include "render/shader.h"
22 #include "render/constant_fold.h"
23
24 #include "util/util_algorithm.h"
25 #include "util/util_foreach.h"
26 #include "util/util_queue.h"
27 #include "util/util_logging.h"
28
29 CCL_NAMESPACE_BEGIN
30
31 namespace {
32
33 bool check_node_inputs_has_links(const ShaderNode *node)
34 {
35         foreach(const ShaderInput *in, node->inputs) {
36                 if(in->link) {
37                         return true;
38                 }
39         }
40         return false;
41 }
42
43 bool check_node_inputs_traversed(const ShaderNode *node,
44                                  const ShaderNodeSet& done)
45 {
46         foreach(const ShaderInput *in, node->inputs) {
47                 if(in->link) {
48                         if(done.find(in->link->parent) == done.end()) {
49                                 return false;
50                         }
51                 }
52         }
53         return true;
54 }
55
56 }  /* namespace */
57
58 /* Node */
59
60 ShaderNode::ShaderNode(const NodeType *type)
61 : Node(type)
62 {
63         name = type->name;
64         id = -1;
65         bump = SHADER_BUMP_NONE;
66         special_type = SHADER_SPECIAL_TYPE_NONE;
67
68         create_inputs_outputs(type);
69 }
70
71 ShaderNode::~ShaderNode()
72 {
73         foreach(ShaderInput *socket, inputs)
74                 delete socket;
75
76         foreach(ShaderOutput *socket, outputs)
77                 delete socket;
78 }
79
80 void ShaderNode::create_inputs_outputs(const NodeType *type)
81 {
82         foreach(const SocketType& socket, type->inputs) {
83                 if(socket.flags & SocketType::LINKABLE) {
84                         inputs.push_back(new ShaderInput(socket, this));
85                 }
86         }
87
88         foreach(const SocketType& socket, type->outputs) {
89                 outputs.push_back(new ShaderOutput(socket, this));
90         }
91 }
92
93 ShaderInput *ShaderNode::input(const char *name)
94 {
95         foreach(ShaderInput *socket, inputs) {
96                 if(socket->name() == name)
97                         return socket;
98         }
99
100         return NULL;
101 }
102
103 ShaderOutput *ShaderNode::output(const char *name)
104 {
105         foreach(ShaderOutput *socket, outputs)
106                 if(socket->name() == name)
107                         return socket;
108
109         return NULL;
110 }
111
112 ShaderInput *ShaderNode::input(ustring name)
113 {
114         foreach(ShaderInput *socket, inputs) {
115                 if(socket->name() == name)
116                         return socket;
117         }
118
119         return NULL;
120 }
121
122 ShaderOutput *ShaderNode::output(ustring name)
123 {
124         foreach(ShaderOutput *socket, outputs)
125                 if(socket->name() == name)
126                         return socket;
127
128         return NULL;
129 }
130
131 void ShaderNode::attributes(Shader *shader, AttributeRequestSet *attributes)
132 {
133         foreach(ShaderInput *input, inputs) {
134                 if(!input->link) {
135                         if(input->flags() & SocketType::LINK_TEXTURE_GENERATED) {
136                                 if(shader->has_surface)
137                                         attributes->add(ATTR_STD_GENERATED);
138                                 if(shader->has_volume)
139                                         attributes->add(ATTR_STD_GENERATED_TRANSFORM);
140                         }
141                         else if(input->flags() & SocketType::LINK_TEXTURE_UV) {
142                                 if(shader->has_surface)
143                                         attributes->add(ATTR_STD_UV);
144                         }
145                 }
146         }
147 }
148
149 bool ShaderNode::equals(const ShaderNode& other)
150 {
151         if(type != other.type || bump != other.bump) {
152                 return false;
153         }
154
155         assert(inputs.size() == other.inputs.size());
156
157         /* Compare unlinkable sockets */
158         foreach(const SocketType& socket, type->inputs) {
159                 if(!(socket.flags & SocketType::LINKABLE)) {
160                         if(!Node::equals_value(other, socket)) {
161                                 return false;
162                         }
163                 }
164         }
165
166         /* Compare linkable input sockets */
167         for(int i = 0; i < inputs.size(); ++i) {
168                 ShaderInput *input_a = inputs[i],
169                             *input_b = other.inputs[i];
170                 if(input_a->link == NULL && input_b->link == NULL) {
171                         /* Unconnected inputs are expected to have the same value. */
172                         if(!Node::equals_value(other, input_a->socket_type)) {
173                                 return false;
174                         }
175                 }
176                 else if(input_a->link != NULL && input_b->link != NULL) {
177                         /* Expect links are to come from the same exact socket. */
178                         if(input_a->link != input_b->link) {
179                                 return false;
180                         }
181                 }
182                 else {
183                         /* One socket has a link and another has not, inputs can't be
184                          * considered equal.
185                          */
186                         return false;
187                 }
188         }
189
190         return true;
191 }
192
193 /* Graph */
194
195 ShaderGraph::ShaderGraph()
196 {
197         finalized = false;
198         simplified = false;
199         num_node_ids = 0;
200         add(new OutputNode());
201 }
202
203 ShaderGraph::~ShaderGraph()
204 {
205         clear_nodes();
206 }
207
208 ShaderNode *ShaderGraph::add(ShaderNode *node)
209 {
210         assert(!finalized);
211         simplified = false;
212
213         node->id = num_node_ids++;
214         nodes.push_back(node);
215         return node;
216 }
217
218 OutputNode *ShaderGraph::output()
219 {
220         return (OutputNode*)nodes.front();
221 }
222
223 void ShaderGraph::connect(ShaderOutput *from, ShaderInput *to)
224 {
225         assert(!finalized);
226         assert(from && to);
227
228         if(to->link) {
229                 fprintf(stderr, "Cycles shader graph connect: input already connected.\n");
230                 return;
231         }
232
233         if(from->type() != to->type()) {
234                 /* for closures we can't do automatic conversion */
235                 if(from->type() == SocketType::CLOSURE || to->type() == SocketType::CLOSURE) {
236                         fprintf(stderr, "Cycles shader graph connect: can only connect closure to closure "
237                                 "(%s.%s to %s.%s).\n",
238                                 from->parent->name.c_str(), from->name().c_str(),
239                                 to->parent->name.c_str(), to->name().c_str());
240                         return;
241                 }
242
243                 /* add automatic conversion node in case of type mismatch */
244                 ShaderNode *convert = add(new ConvertNode(from->type(), to->type(), true));
245
246                 connect(from, convert->inputs[0]);
247                 connect(convert->outputs[0], to);
248         }
249         else {
250                 /* types match, just connect */
251                 to->link = from;
252                 from->links.push_back(to);
253         }
254 }
255
256 void ShaderGraph::disconnect(ShaderOutput *from)
257 {
258         assert(!finalized);
259         simplified = false;
260
261         foreach(ShaderInput *sock, from->links) {
262                 sock->link = NULL;
263         }
264
265         from->links.clear();
266 }
267
268 void ShaderGraph::disconnect(ShaderInput *to)
269 {
270         assert(!finalized);
271         assert(to->link);
272         simplified = false;
273
274         ShaderOutput *from = to->link;
275
276         to->link = NULL;
277         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
278 }
279
280 void ShaderGraph::relink(ShaderNode *node, ShaderOutput *from, ShaderOutput *to)
281 {
282         simplified = false;
283
284         /* Copy because disconnect modifies this list */
285         vector<ShaderInput*> outputs = from->links;
286
287         /* Bypass node by moving all links from "from" to "to" */
288         foreach(ShaderInput *sock, node->inputs) {
289                 if(sock->link)
290                         disconnect(sock);
291         }
292
293         foreach(ShaderInput *sock, outputs) {
294                 disconnect(sock);
295                 if(to)
296                         connect(to, sock);
297         }
298 }
299
300 void ShaderGraph::simplify(Scene *scene)
301 {
302         if(!simplified) {
303                 default_inputs(scene->shader_manager->use_osl());
304                 clean(scene);
305                 refine_bump_nodes();
306
307                 simplified = true;
308         }
309 }
310
311 void ShaderGraph::finalize(Scene *scene,
312                            bool do_bump,
313                            bool do_simplify,
314                            bool bump_in_object_space)
315 {
316         /* before compiling, the shader graph may undergo a number of modifications.
317          * currently we set default geometry shader inputs, and create automatic bump
318          * from displacement. a graph can be finalized only once, and should not be
319          * modified afterwards. */
320
321         if(!finalized) {
322                 simplify(scene);
323
324                 if(do_bump)
325                         bump_from_displacement(bump_in_object_space);
326
327                 ShaderInput *surface_in = output()->input("Surface");
328                 ShaderInput *volume_in = output()->input("Volume");
329
330                 /* todo: make this work when surface and volume closures are tangled up */
331
332                 if(surface_in->link)
333                         transform_multi_closure(surface_in->link->parent, NULL, false);
334                 if(volume_in->link)
335                         transform_multi_closure(volume_in->link->parent, NULL, true);
336
337                 finalized = true;
338         }
339         else if(do_simplify) {
340                 simplify_settings(scene);
341         }
342 }
343
344 void ShaderGraph::find_dependencies(ShaderNodeSet& dependencies, ShaderInput *input)
345 {
346         /* find all nodes that this input depends on directly and indirectly */
347         ShaderNode *node = (input->link)? input->link->parent: NULL;
348
349         if(node != NULL && dependencies.find(node) == dependencies.end()) {
350                 foreach(ShaderInput *in, node->inputs)
351                         find_dependencies(dependencies, in);
352
353                 dependencies.insert(node);
354         }
355 }
356
357 void ShaderGraph::clear_nodes()
358 {
359         foreach(ShaderNode *node, nodes) {
360                 delete node;
361         }
362         nodes.clear();
363 }
364
365 void ShaderGraph::copy_nodes(ShaderNodeSet& nodes, ShaderNodeMap& nnodemap)
366 {
367         /* copy a set of nodes, and the links between them. the assumption is
368          * made that all nodes that inputs are linked to are in the set too. */
369
370         /* copy nodes */
371         foreach(ShaderNode *node, nodes) {
372                 ShaderNode *nnode = node->clone();
373                 nnodemap[node] = nnode;
374
375                 /* create new inputs and outputs to recreate links and ensure
376                  * that we still point to valid SocketType if the NodeType
377                  * changed in cloning, as it does for OSL nodes */
378                 nnode->inputs.clear();
379                 nnode->outputs.clear();
380                 nnode->create_inputs_outputs(nnode->type);
381         }
382
383         /* recreate links */
384         foreach(ShaderNode *node, nodes) {
385                 foreach(ShaderInput *input, node->inputs) {
386                         if(input->link) {
387                                 /* find new input and output */
388                                 ShaderNode *nfrom = nnodemap[input->link->parent];
389                                 ShaderNode *nto = nnodemap[input->parent];
390                                 ShaderOutput *noutput = nfrom->output(input->link->name());
391                                 ShaderInput *ninput = nto->input(input->name());
392
393                                 /* connect */
394                                 connect(noutput, ninput);
395                         }
396                 }
397         }
398 }
399
400 /* Graph simplification */
401 /* ******************** */
402
403 /* Remove proxy nodes.
404  *
405  * These only exists temporarily when exporting groups, and we must remove them
406  * early so that node->attributes() and default links do not see them.
407  */
408 void ShaderGraph::remove_proxy_nodes()
409 {
410         vector<bool> removed(num_node_ids, false);
411         bool any_node_removed = false;
412
413         foreach(ShaderNode *node, nodes) {
414                 if(node->special_type == SHADER_SPECIAL_TYPE_PROXY) {
415                         ConvertNode *proxy = static_cast<ConvertNode*>(node);
416                         ShaderInput *input = proxy->inputs[0];
417                         ShaderOutput *output = proxy->outputs[0];
418
419                         /* bypass the proxy node */
420                         if(input->link) {
421                                 relink(proxy, output, input->link);
422                         }
423                         else {
424                                 /* Copy because disconnect modifies this list */
425                                 vector<ShaderInput*> links(output->links);
426
427                                 foreach(ShaderInput *to, links) {
428                                         /* remove any autoconvert nodes too if they lead to
429                                          * sockets with an automatically set default value */
430                                         ShaderNode *tonode = to->parent;
431
432                                         if(tonode->special_type == SHADER_SPECIAL_TYPE_AUTOCONVERT) {
433                                                 bool all_links_removed = true;
434                                                 vector<ShaderInput*> links = tonode->outputs[0]->links;
435
436                                                 foreach(ShaderInput *autoin, links) {
437                                                         if(autoin->flags() & SocketType::DEFAULT_LINK_MASK)
438                                                                 disconnect(autoin);
439                                                         else
440                                                                 all_links_removed = false;
441                                                 }
442
443                                                 if(all_links_removed)
444                                                         removed[tonode->id] = true;
445                                         }
446
447                                         disconnect(to);
448
449                                         /* transfer the default input value to the target socket */
450                                         tonode->copy_value(to->socket_type, *proxy, input->socket_type);
451                                 }
452                         }
453
454                         removed[proxy->id] = true;
455                         any_node_removed = true;
456                 }
457         }
458
459         /* remove nodes */
460         if(any_node_removed) {
461                 list<ShaderNode*> newnodes;
462
463                 foreach(ShaderNode *node, nodes) {
464                         if(!removed[node->id])
465                                 newnodes.push_back(node);
466                         else
467                                 delete node;
468                 }
469
470                 nodes = newnodes;
471         }
472 }
473
474 /* Constant folding.
475  *
476  * Try to constant fold some nodes, and pipe result directly to
477  * the input socket of connected nodes.
478  */
479 void ShaderGraph::constant_fold()
480 {
481         ShaderNodeSet done, scheduled;
482         queue<ShaderNode*> traverse_queue;
483
484         bool has_displacement = (output()->input("Displacement")->link != NULL);
485
486         /* Schedule nodes which doesn't have any dependencies. */
487         foreach(ShaderNode *node, nodes) {
488                 if(!check_node_inputs_has_links(node)) {
489                         traverse_queue.push(node);
490                         scheduled.insert(node);
491                 }
492         }
493
494         while(!traverse_queue.empty()) {
495                 ShaderNode *node = traverse_queue.front();
496                 traverse_queue.pop();
497                 done.insert(node);
498                 foreach(ShaderOutput *output, node->outputs) {
499                         if(output->links.size() == 0) {
500                                 continue;
501                         }
502                         /* Schedule node which was depending on the value,
503                          * when possible. Do it before disconnect.
504                          */
505                         foreach(ShaderInput *input, output->links) {
506                                 if(scheduled.find(input->parent) != scheduled.end()) {
507                                         /* Node might not be optimized yet but scheduled already
508                                          * by other dependencies. No need to re-schedule it.
509                                          */
510                                         continue;
511                                 }
512                                 /* Schedule node if its inputs are fully done. */
513                                 if(check_node_inputs_traversed(input->parent, done)) {
514                                         traverse_queue.push(input->parent);
515                                         scheduled.insert(input->parent);
516                                 }
517                         }
518                         /* Optimize current node. */
519                         ConstantFolder folder(this, node, output);
520                         node->constant_fold(folder);
521                 }
522         }
523
524         /* Folding might have removed all nodes connected to the displacement output
525          * even tho there is displacement to be applied, so add in a value node if
526          * that happens to ensure there is still a valid graph for displacement.
527          */
528         if(has_displacement && !output()->input("Displacement")->link) {
529                 ValueNode *value = (ValueNode*)add(new ValueNode());
530                 value->value = output()->displacement;
531
532                 connect(value->output("Value"), output()->input("Displacement"));
533         }
534 }
535
536 /* Simplification. */
537 void ShaderGraph::simplify_settings(Scene *scene)
538 {
539         foreach(ShaderNode *node, nodes) {
540                 node->simplify_settings(scene);
541         }
542 }
543
544 /* Deduplicate nodes with same settings. */
545 void ShaderGraph::deduplicate_nodes()
546 {
547         /* NOTES:
548          * - Deduplication happens for nodes which has same exact settings and same
549          *   exact input links configuration (either connected to same output or has
550          *   the same exact default value).
551          * - Deduplication happens in the bottom-top manner, so we know for fact that
552          *   all traversed nodes are either can not be deduplicated at all or were
553          *   already deduplicated.
554          */
555
556         ShaderNodeSet scheduled, done;
557         map<ustring, ShaderNodeSet> candidates;
558         queue<ShaderNode*> traverse_queue;
559         int num_deduplicated = 0;
560
561         /* Schedule nodes which doesn't have any dependencies. */
562         foreach(ShaderNode *node, nodes) {
563                 if(!check_node_inputs_has_links(node)) {
564                         traverse_queue.push(node);
565                         scheduled.insert(node);
566                 }
567         }
568
569         while(!traverse_queue.empty()) {
570                 ShaderNode *node = traverse_queue.front();
571                 traverse_queue.pop();
572                 done.insert(node);
573                 /* Schedule the nodes which were depending on the current node. */
574                 bool has_output_links = false;
575                 foreach(ShaderOutput *output, node->outputs) {
576                         foreach(ShaderInput *input, output->links) {
577                                 has_output_links = true;
578                                 if(scheduled.find(input->parent) != scheduled.end()) {
579                                         /* Node might not be optimized yet but scheduled already
580                                          * by other dependencies. No need to re-schedule it.
581                                          */
582                                         continue;
583                                 }
584                                 /* Schedule node if its inputs are fully done. */
585                                 if(check_node_inputs_traversed(input->parent, done)) {
586                                         traverse_queue.push(input->parent);
587                                         scheduled.insert(input->parent);
588                                 }
589                         }
590                 }
591                 /* Only need to care about nodes that are actually used */
592                 if(!has_output_links) {
593                         continue;
594                 }
595                 /* Try to merge this node with another one. */
596                 ShaderNode *merge_with = NULL;
597                 foreach(ShaderNode *other_node, candidates[node->type->name]) {
598                         if(node != other_node && node->equals(*other_node)) {
599                                 merge_with = other_node;
600                                 break;
601                         }
602                 }
603                 /* If found an equivalent, merge; otherwise keep node for later merges */
604                 if(merge_with != NULL) {
605                         for(int i = 0; i < node->outputs.size(); ++i) {
606                                 relink(node, node->outputs[i], merge_with->outputs[i]);
607                         }
608                         num_deduplicated++;
609                 }
610                 else {
611                         candidates[node->type->name].insert(node);
612                 }
613         }
614
615         if(num_deduplicated > 0) {
616                 VLOG(1) << "Deduplicated " << num_deduplicated << " nodes.";
617         }
618 }
619
620 /* Check whether volume output has meaningful nodes, otherwise
621  * disconnect the output.
622  */
623 void ShaderGraph::verify_volume_output()
624 {
625         /* Check whether we can optimize the whole volume graph out. */
626         ShaderInput *volume_in = output()->input("Volume");
627         if(volume_in->link == NULL) {
628                 return;
629         }
630         bool has_valid_volume = false;
631         ShaderNodeSet scheduled;
632         queue<ShaderNode*> traverse_queue;
633         /* Schedule volume output. */
634         traverse_queue.push(volume_in->link->parent);
635         scheduled.insert(volume_in->link->parent);
636         /* Traverse down the tree. */
637         while(!traverse_queue.empty()) {
638                 ShaderNode *node = traverse_queue.front();
639                 traverse_queue.pop();
640                 /* Node is fully valid for volume, can't optimize anything out. */
641                 if(node->has_volume_support()) {
642                         has_valid_volume = true;
643                         break;
644                 }
645                 foreach(ShaderInput *input, node->inputs) {
646                         if(input->link == NULL) {
647                                 continue;
648                         }
649                         if(scheduled.find(input->link->parent) != scheduled.end()) {
650                                 continue;
651                         }
652                         traverse_queue.push(input->link->parent);
653                         scheduled.insert(input->link->parent);
654                 }
655         }
656         if(!has_valid_volume) {
657                 VLOG(1) << "Disconnect meaningless volume output.";
658                 disconnect(volume_in->link);
659         }
660 }
661
662 void ShaderGraph::break_cycles(ShaderNode *node, vector<bool>& visited, vector<bool>& on_stack)
663 {
664         visited[node->id] = true;
665         on_stack[node->id] = true;
666
667         foreach(ShaderInput *input, node->inputs) {
668                 if(input->link) {
669                         ShaderNode *depnode = input->link->parent;
670
671                         if(on_stack[depnode->id]) {
672                                 /* break cycle */
673                                 disconnect(input);
674                                 fprintf(stderr, "Cycles shader graph: detected cycle in graph, connection removed.\n");
675                         }
676                         else if(!visited[depnode->id]) {
677                                 /* visit dependencies */
678                                 break_cycles(depnode, visited, on_stack);
679                         }
680                 }
681         }
682
683         on_stack[node->id] = false;
684 }
685
686 void ShaderGraph::clean(Scene *scene)
687 {
688         /* Graph simplification */
689
690         /* NOTE: Remove proxy nodes was already done. */
691         constant_fold();
692         simplify_settings(scene);
693         deduplicate_nodes();
694         verify_volume_output();
695
696         /* we do two things here: find cycles and break them, and remove unused
697          * nodes that don't feed into the output. how cycles are broken is
698          * undefined, they are invalid input, the important thing is to not crash */
699
700         vector<bool> visited(num_node_ids, false);
701         vector<bool> on_stack(num_node_ids, false);
702         
703         /* break cycles */
704         break_cycles(output(), visited, on_stack);
705
706         /* disconnect unused nodes */
707         foreach(ShaderNode *node, nodes) {
708                 if(!visited[node->id]) {
709                         foreach(ShaderInput *to, node->inputs) {
710                                 ShaderOutput *from = to->link;
711
712                                 if(from) {
713                                         to->link = NULL;
714                                         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
715                                 }
716                         }
717                 }
718         }
719
720         /* remove unused nodes */
721         list<ShaderNode*> newnodes;
722
723         foreach(ShaderNode *node, nodes) {
724                 if(visited[node->id])
725                         newnodes.push_back(node);
726                 else
727                         delete node;
728         }
729
730         nodes = newnodes;
731 }
732
733 void ShaderGraph::default_inputs(bool do_osl)
734 {
735         /* nodes can specify default texture coordinates, for now we give
736          * everything the position by default, except for the sky texture */
737
738         ShaderNode *geom = NULL;
739         ShaderNode *texco = NULL;
740
741         foreach(ShaderNode *node, nodes) {
742                 foreach(ShaderInput *input, node->inputs) {
743                         if(!input->link && (!(input->flags() & SocketType::OSL_INTERNAL) || do_osl)) {
744                                 if(input->flags() & SocketType::LINK_TEXTURE_GENERATED) {
745                                         if(!texco)
746                                                 texco = new TextureCoordinateNode();
747
748                                         connect(texco->output("Generated"), input);
749                                 }
750                                 else if(input->flags() & SocketType::LINK_TEXTURE_UV) {
751                                         if(!texco)
752                                                 texco = new TextureCoordinateNode();
753
754                                         connect(texco->output("UV"), input);
755                                 }
756                                 else if(input->flags() & SocketType::LINK_INCOMING) {
757                                         if(!geom)
758                                                 geom = new GeometryNode();
759
760                                         connect(geom->output("Incoming"), input);
761                                 }
762                                 else if(input->flags() & SocketType::LINK_NORMAL) {
763                                         if(!geom)
764                                                 geom = new GeometryNode();
765
766                                         connect(geom->output("Normal"), input);
767                                 }
768                                 else if(input->flags() & SocketType::LINK_POSITION) {
769                                         if(!geom)
770                                                 geom = new GeometryNode();
771
772                                         connect(geom->output("Position"), input);
773                                 }
774                                 else if(input->flags() & SocketType::LINK_TANGENT) {
775                                         if(!geom)
776                                                 geom = new GeometryNode();
777
778                                         connect(geom->output("Tangent"), input);
779                                 }
780                         }
781                 }
782         }
783
784         if(geom)
785                 add(geom);
786         if(texco)
787                 add(texco);
788 }
789
790 void ShaderGraph::refine_bump_nodes()
791 {
792         /* we transverse the node graph looking for bump nodes, when we find them,
793          * like in bump_from_displacement(), we copy the sub-graph defined from "bump"
794          * input to the inputs "center","dx" and "dy" What is in "bump" input is moved
795          * to "center" input. */
796
797         foreach(ShaderNode *node, nodes) {
798                 if(node->special_type == SHADER_SPECIAL_TYPE_BUMP && node->input("Height")->link) {
799                         ShaderInput *bump_input = node->input("Height");
800                         ShaderNodeSet nodes_bump;
801
802                         /* make 2 extra copies of the subgraph defined in Bump input */
803                         ShaderNodeMap nodes_dx;
804                         ShaderNodeMap nodes_dy;
805
806                         /* find dependencies for the given input */
807                         find_dependencies(nodes_bump, bump_input);
808
809                         copy_nodes(nodes_bump, nodes_dx);
810                         copy_nodes(nodes_bump, nodes_dy);
811         
812                         /* mark nodes to indicate they are use for bump computation, so
813                            that any texture coordinates are shifted by dx/dy when sampling */
814                         foreach(ShaderNode *node, nodes_bump)
815                                 node->bump = SHADER_BUMP_CENTER;
816                         foreach(NodePair& pair, nodes_dx)
817                                 pair.second->bump = SHADER_BUMP_DX;
818                         foreach(NodePair& pair, nodes_dy)
819                                 pair.second->bump = SHADER_BUMP_DY;
820
821                         ShaderOutput *out = bump_input->link;
822                         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name());
823                         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name());
824
825                         connect(out_dx, node->input("SampleX"));
826                         connect(out_dy, node->input("SampleY"));
827                         
828                         /* add generated nodes */
829                         foreach(NodePair& pair, nodes_dx)
830                                 add(pair.second);
831                         foreach(NodePair& pair, nodes_dy)
832                                 add(pair.second);
833                         
834                         /* connect what is connected is bump to samplecenter input*/
835                         connect(out , node->input("SampleCenter"));
836
837                         /* bump input is just for connectivity purpose for the graph input,
838                          * we re-connected this input to samplecenter, so lets disconnect it
839                          * from bump input */
840                         disconnect(bump_input);
841                 }
842         }
843 }
844
845 void ShaderGraph::bump_from_displacement(bool use_object_space)
846 {
847         /* generate bump mapping automatically from displacement. bump mapping is
848          * done using a 3-tap filter, computing the displacement at the center,
849          * and two other positions shifted by ray differentials.
850          *
851          * since the input to displacement is a node graph, we need to ensure that
852          * all texture coordinates use are shift by the ray differentials. for this
853          * reason we make 3 copies of the node subgraph defining the displacement,
854          * with each different geometry and texture coordinate nodes that generate
855          * different shifted coordinates.
856          *
857          * these 3 displacement values are then fed into the bump node, which will
858          * output the perturbed normal. */
859
860         ShaderInput *displacement_in = output()->input("Displacement");
861
862         if(!displacement_in->link)
863                 return;
864         
865         /* find dependencies for the given input */
866         ShaderNodeSet nodes_displace;
867         find_dependencies(nodes_displace, displacement_in);
868
869         /* copy nodes for 3 bump samples */
870         ShaderNodeMap nodes_center;
871         ShaderNodeMap nodes_dx;
872         ShaderNodeMap nodes_dy;
873
874         copy_nodes(nodes_displace, nodes_center);
875         copy_nodes(nodes_displace, nodes_dx);
876         copy_nodes(nodes_displace, nodes_dy);
877
878         /* mark nodes to indicate they are use for bump computation, so
879          * that any texture coordinates are shifted by dx/dy when sampling */
880         foreach(NodePair& pair, nodes_center)
881                 pair.second->bump = SHADER_BUMP_CENTER;
882         foreach(NodePair& pair, nodes_dx)
883                 pair.second->bump = SHADER_BUMP_DX;
884         foreach(NodePair& pair, nodes_dy)
885                 pair.second->bump = SHADER_BUMP_DY;
886
887         /* add set normal node and connect the bump normal ouput to the set normal
888          * output, so it can finally set the shader normal, note we are only doing
889          * this for bump from displacement, this will be the only bump allowed to
890          * overwrite the shader normal */
891         ShaderNode *set_normal = add(new SetNormalNode());
892         
893         /* add bump node and connect copied graphs to it */
894         BumpNode *bump = (BumpNode*)add(new BumpNode());
895         bump->use_object_space = use_object_space;
896
897         ShaderOutput *out = displacement_in->link;
898         ShaderOutput *out_center = nodes_center[out->parent]->output(out->name());
899         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name());
900         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name());
901
902         connect(out_center, bump->input("SampleCenter"));
903         connect(out_dx, bump->input("SampleX"));
904         connect(out_dy, bump->input("SampleY"));
905         
906         /* connect the bump out to the set normal in: */
907         connect(bump->output("Normal"), set_normal->input("Direction"));
908
909         /* connect to output node */
910         connect(set_normal->output("Normal"), output()->input("Normal"));
911
912         /* finally, add the copied nodes to the graph. we can't do this earlier
913          * because we would create dependency cycles in the above loop */
914         foreach(NodePair& pair, nodes_center)
915                 add(pair.second);
916         foreach(NodePair& pair, nodes_dx)
917                 add(pair.second);
918         foreach(NodePair& pair, nodes_dy)
919                 add(pair.second);
920 }
921
922 void ShaderGraph::transform_multi_closure(ShaderNode *node, ShaderOutput *weight_out, bool volume)
923 {
924         /* for SVM in multi closure mode, this transforms the shader mix/add part of
925          * the graph into nodes that feed weights into closure nodes. this is too
926          * avoid building a closure tree and then flattening it, and instead write it
927          * directly to an array */
928         
929         if(node->special_type == SHADER_SPECIAL_TYPE_COMBINE_CLOSURE) {
930                 ShaderInput *fin = node->input("Fac");
931                 ShaderInput *cl1in = node->input("Closure1");
932                 ShaderInput *cl2in = node->input("Closure2");
933                 ShaderOutput *weight1_out, *weight2_out;
934
935                 if(fin) {
936                         /* mix closure: add node to mix closure weights */
937                         MixClosureWeightNode *mix_node = new MixClosureWeightNode();
938                         add(mix_node);
939                         ShaderInput *fac_in = mix_node->input("Fac"); 
940                         ShaderInput *weight_in = mix_node->input("Weight"); 
941
942                         if(fin->link)
943                                 connect(fin->link, fac_in);
944                         else
945                                 mix_node->fac = node->get_float(fin->socket_type);
946
947                         if(weight_out)
948                                 connect(weight_out, weight_in);
949
950                         weight1_out = mix_node->output("Weight1");
951                         weight2_out = mix_node->output("Weight2");
952                 }
953                 else {
954                         /* add closure: just pass on any weights */
955                         weight1_out = weight_out;
956                         weight2_out = weight_out;
957                 }
958
959                 if(cl1in->link)
960                         transform_multi_closure(cl1in->link->parent, weight1_out, volume);
961                 if(cl2in->link)
962                         transform_multi_closure(cl2in->link->parent, weight2_out, volume);
963         }
964         else {
965                 ShaderInput *weight_in = node->input((volume)? "VolumeMixWeight": "SurfaceMixWeight");
966
967                 /* not a closure node? */
968                 if(!weight_in)
969                         return;
970
971                 /* already has a weight connected to it? add weights */
972                 float weight_value = node->get_float(weight_in->socket_type);
973                 if(weight_in->link || weight_value != 0.0f) {
974                         MathNode *math_node = new MathNode();
975                         add(math_node);
976
977                         if(weight_in->link)
978                                 connect(weight_in->link, math_node->input("Value1"));
979                         else
980                                 math_node->value1 = weight_value;
981
982                         if(weight_out)
983                                 connect(weight_out, math_node->input("Value2"));
984                         else
985                                 math_node->value2 = 1.0f;
986
987                         weight_out = math_node->output("Value");
988                         if(weight_in->link)
989                                 disconnect(weight_in);
990                 }
991
992                 /* connected to closure mix weight */
993                 if(weight_out)
994                         connect(weight_out, weight_in);
995                 else
996                         node->set(weight_in->socket_type, weight_value + 1.0f);
997         }
998 }
999
1000 int ShaderGraph::get_num_closures()
1001 {
1002         int num_closures = 0;
1003         foreach(ShaderNode *node, nodes) {
1004                 ClosureType closure_type = node->get_closure_type();
1005                 if(closure_type == CLOSURE_NONE_ID) {
1006                         continue;
1007                 }
1008                 else if(CLOSURE_IS_BSSRDF(closure_type)) {
1009                         num_closures += 3;
1010                 }
1011                 else if(CLOSURE_IS_GLASS(closure_type)) {
1012                         num_closures += 2;
1013                 }
1014                 else if(CLOSURE_IS_BSDF_MULTISCATTER(closure_type)) {
1015                         num_closures += 2;
1016                 }
1017                 else if(CLOSURE_IS_PRINCIPLED(closure_type)) {
1018                         num_closures += 8;
1019                 }
1020                 else if(CLOSURE_IS_VOLUME(closure_type)) {
1021                         num_closures += VOLUME_STACK_SIZE;
1022                 }
1023                 else {
1024                         ++num_closures;
1025                 }
1026         }
1027         return num_closures;
1028 }
1029
1030 void ShaderGraph::dump_graph(const char *filename)
1031 {
1032         FILE *fd = fopen(filename, "w");
1033
1034         if(fd == NULL) {
1035                 printf("Error opening file for dumping the graph: %s\n", filename);
1036                 return;
1037         }
1038
1039         fprintf(fd, "digraph shader_graph {\n");
1040         fprintf(fd, "ranksep=1.5\n");
1041         fprintf(fd, "rankdir=LR\n");
1042         fprintf(fd, "splines=false\n");
1043
1044         foreach(ShaderNode *node, nodes) {
1045                 fprintf(fd, "// NODE: %p\n", node);
1046                 fprintf(fd, "\"%p\" [shape=record,label=\"{", node);
1047                 if(node->inputs.size()) {
1048                         fprintf(fd, "{");
1049                         foreach(ShaderInput *socket, node->inputs) {
1050                                 if(socket != node->inputs[0]) {
1051                                         fprintf(fd, "|");
1052                                 }
1053                                 fprintf(fd, "<IN_%p>%s", socket, socket->name().c_str());
1054                         }
1055                         fprintf(fd, "}|");
1056                 }
1057                 fprintf(fd, "%s", node->name.c_str());
1058                 if(node->bump == SHADER_BUMP_CENTER) {
1059                         fprintf(fd, " (bump:center)");
1060                 }
1061                 else if(node->bump == SHADER_BUMP_DX) {
1062                         fprintf(fd, " (bump:dx)");
1063                 }
1064                 else if(node->bump == SHADER_BUMP_DY) {
1065                         fprintf(fd, " (bump:dy)");
1066                 }
1067                 if(node->outputs.size()) {
1068                         fprintf(fd, "|{");
1069                         foreach(ShaderOutput *socket, node->outputs) {
1070                                 if(socket != node->outputs[0]) {
1071                                         fprintf(fd, "|");
1072                                 }
1073                                 fprintf(fd, "<OUT_%p>%s", socket, socket->name().c_str());
1074                         }
1075                         fprintf(fd, "}");
1076                 }
1077                 fprintf(fd, "}\"]");
1078         }
1079
1080         foreach(ShaderNode *node, nodes) {
1081                 foreach(ShaderOutput *output, node->outputs) {
1082                         foreach(ShaderInput *input, output->links) {
1083                                 fprintf(fd,
1084                                         "// CONNECTION: OUT_%p->IN_%p (%s:%s)\n",
1085                                         output,
1086                                         input,
1087                                         output->name().c_str(), input->name().c_str());
1088                                 fprintf(fd,
1089                                         "\"%p\":\"OUT_%p\":e -> \"%p\":\"IN_%p\":w [label=\"\"]\n",
1090                                         output->parent,
1091                                         output,
1092                                         input->parent,
1093                                         input);
1094                         }
1095                 }
1096         }
1097
1098         fprintf(fd, "}\n");
1099         fclose(fd);
1100 }
1101
1102 CCL_NAMESPACE_END
1103