Cycles: Query XYZ to/from Scene Linear conversion from OCIO instead of assuming sRGB
[blender.git] / intern / cycles / render / graph.cpp
1 /*
2  * Copyright 2011-2016 Blender Foundation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "render/attribute.h"
18 #include "render/graph.h"
19 #include "render/nodes.h"
20 #include "render/scene.h"
21 #include "render/shader.h"
22 #include "render/constant_fold.h"
23
24 #include "util/util_algorithm.h"
25 #include "util/util_foreach.h"
26 #include "util/util_logging.h"
27 #include "util/util_md5.h"
28 #include "util/util_queue.h"
29
30 CCL_NAMESPACE_BEGIN
31
32 namespace {
33
34 bool check_node_inputs_has_links(const ShaderNode *node)
35 {
36         foreach(const ShaderInput *in, node->inputs) {
37                 if(in->link) {
38                         return true;
39                 }
40         }
41         return false;
42 }
43
44 bool check_node_inputs_traversed(const ShaderNode *node,
45                                  const ShaderNodeSet& done)
46 {
47         foreach(const ShaderInput *in, node->inputs) {
48                 if(in->link) {
49                         if(done.find(in->link->parent) == done.end()) {
50                                 return false;
51                         }
52                 }
53         }
54         return true;
55 }
56
57 }  /* namespace */
58
59 /* Node */
60
61 ShaderNode::ShaderNode(const NodeType *type)
62 : Node(type)
63 {
64         name = type->name;
65         id = -1;
66         bump = SHADER_BUMP_NONE;
67         special_type = SHADER_SPECIAL_TYPE_NONE;
68
69         create_inputs_outputs(type);
70 }
71
72 ShaderNode::~ShaderNode()
73 {
74         foreach(ShaderInput *socket, inputs)
75                 delete socket;
76
77         foreach(ShaderOutput *socket, outputs)
78                 delete socket;
79 }
80
81 void ShaderNode::create_inputs_outputs(const NodeType *type)
82 {
83         foreach(const SocketType& socket, type->inputs) {
84                 if(socket.flags & SocketType::LINKABLE) {
85                         inputs.push_back(new ShaderInput(socket, this));
86                 }
87         }
88
89         foreach(const SocketType& socket, type->outputs) {
90                 outputs.push_back(new ShaderOutput(socket, this));
91         }
92 }
93
94 ShaderInput *ShaderNode::input(const char *name)
95 {
96         foreach(ShaderInput *socket, inputs) {
97                 if(socket->name() == name)
98                         return socket;
99         }
100
101         return NULL;
102 }
103
104 ShaderOutput *ShaderNode::output(const char *name)
105 {
106         foreach(ShaderOutput *socket, outputs)
107                 if(socket->name() == name)
108                         return socket;
109
110         return NULL;
111 }
112
113 ShaderInput *ShaderNode::input(ustring name)
114 {
115         foreach(ShaderInput *socket, inputs) {
116                 if(socket->name() == name)
117                         return socket;
118         }
119
120         return NULL;
121 }
122
123 ShaderOutput *ShaderNode::output(ustring name)
124 {
125         foreach(ShaderOutput *socket, outputs)
126                 if(socket->name() == name)
127                         return socket;
128
129         return NULL;
130 }
131
132 void ShaderNode::attributes(Shader *shader, AttributeRequestSet *attributes)
133 {
134         foreach(ShaderInput *input, inputs) {
135                 if(!input->link) {
136                         if(input->flags() & SocketType::LINK_TEXTURE_GENERATED) {
137                                 if(shader->has_surface)
138                                         attributes->add(ATTR_STD_GENERATED);
139                                 if(shader->has_volume)
140                                         attributes->add(ATTR_STD_GENERATED_TRANSFORM);
141                         }
142                         else if(input->flags() & SocketType::LINK_TEXTURE_UV) {
143                                 if(shader->has_surface)
144                                         attributes->add(ATTR_STD_UV);
145                         }
146                 }
147         }
148 }
149
150 bool ShaderNode::equals(const ShaderNode& other)
151 {
152         if(type != other.type || bump != other.bump) {
153                 return false;
154         }
155
156         assert(inputs.size() == other.inputs.size());
157
158         /* Compare unlinkable sockets */
159         foreach(const SocketType& socket, type->inputs) {
160                 if(!(socket.flags & SocketType::LINKABLE)) {
161                         if(!Node::equals_value(other, socket)) {
162                                 return false;
163                         }
164                 }
165         }
166
167         /* Compare linkable input sockets */
168         for(int i = 0; i < inputs.size(); ++i) {
169                 ShaderInput *input_a = inputs[i],
170                             *input_b = other.inputs[i];
171                 if(input_a->link == NULL && input_b->link == NULL) {
172                         /* Unconnected inputs are expected to have the same value. */
173                         if(!Node::equals_value(other, input_a->socket_type)) {
174                                 return false;
175                         }
176                 }
177                 else if(input_a->link != NULL && input_b->link != NULL) {
178                         /* Expect links are to come from the same exact socket. */
179                         if(input_a->link != input_b->link) {
180                                 return false;
181                         }
182                 }
183                 else {
184                         /* One socket has a link and another has not, inputs can't be
185                          * considered equal.
186                          */
187                         return false;
188                 }
189         }
190
191         return true;
192 }
193
194 /* Graph */
195
196 ShaderGraph::ShaderGraph()
197 {
198         finalized = false;
199         simplified = false;
200         num_node_ids = 0;
201         add(new OutputNode());
202 }
203
204 ShaderGraph::~ShaderGraph()
205 {
206         clear_nodes();
207 }
208
209 ShaderNode *ShaderGraph::add(ShaderNode *node)
210 {
211         assert(!finalized);
212         simplified = false;
213
214         node->id = num_node_ids++;
215         nodes.push_back(node);
216         return node;
217 }
218
219 OutputNode *ShaderGraph::output()
220 {
221         return (OutputNode*)nodes.front();
222 }
223
224 void ShaderGraph::connect(ShaderOutput *from, ShaderInput *to)
225 {
226         assert(!finalized);
227         assert(from && to);
228
229         if(to->link) {
230                 fprintf(stderr, "Cycles shader graph connect: input already connected.\n");
231                 return;
232         }
233
234         if(from->type() != to->type()) {
235                 /* can't do automatic conversion from closure */
236                 if(from->type() == SocketType::CLOSURE) {
237                         fprintf(stderr, "Cycles shader graph connect: can only connect closure to closure "
238                                 "(%s.%s to %s.%s).\n",
239                                 from->parent->name.c_str(), from->name().c_str(),
240                                 to->parent->name.c_str(), to->name().c_str());
241                         return;
242                 }
243
244                 /* add automatic conversion node in case of type mismatch */
245                 ShaderNode *convert;
246                 ShaderInput *convert_in;
247
248                 if (to->type() == SocketType::CLOSURE) {
249                         EmissionNode *emission = new EmissionNode();
250                         emission->color = make_float3(1.0f, 1.0f, 1.0f);
251                         emission->strength = 1.0f;
252                         convert = add(emission);
253                         /* Connect float inputs to Strength to save an additional Falue->Color conversion. */
254                         if(from->type() == SocketType::FLOAT) {
255                                 convert_in = convert->input("Strength");
256                         }
257                         else {
258                                 convert_in = convert->input("Color");
259                         }
260                 }
261                 else {
262                         convert = add(new ConvertNode(from->type(), to->type(), true));
263                         convert_in = convert->inputs[0];
264                 }
265
266                 connect(from, convert_in);
267                 connect(convert->outputs[0], to);
268         }
269         else {
270                 /* types match, just connect */
271                 to->link = from;
272                 from->links.push_back(to);
273         }
274 }
275
276 void ShaderGraph::disconnect(ShaderOutput *from)
277 {
278         assert(!finalized);
279         simplified = false;
280
281         foreach(ShaderInput *sock, from->links) {
282                 sock->link = NULL;
283         }
284
285         from->links.clear();
286 }
287
288 void ShaderGraph::disconnect(ShaderInput *to)
289 {
290         assert(!finalized);
291         assert(to->link);
292         simplified = false;
293
294         ShaderOutput *from = to->link;
295
296         to->link = NULL;
297         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
298 }
299
300 void ShaderGraph::relink(ShaderNode *node, ShaderOutput *from, ShaderOutput *to)
301 {
302         simplified = false;
303
304         /* Copy because disconnect modifies this list */
305         vector<ShaderInput*> outputs = from->links;
306
307         /* Bypass node by moving all links from "from" to "to" */
308         foreach(ShaderInput *sock, node->inputs) {
309                 if(sock->link)
310                         disconnect(sock);
311         }
312
313         foreach(ShaderInput *sock, outputs) {
314                 disconnect(sock);
315                 if(to)
316                         connect(to, sock);
317         }
318 }
319
320 void ShaderGraph::simplify(Scene *scene)
321 {
322         if(!simplified) {
323                 default_inputs(scene->shader_manager->use_osl());
324                 clean(scene);
325                 refine_bump_nodes();
326
327                 simplified = true;
328         }
329 }
330
331 void ShaderGraph::finalize(Scene *scene,
332                            bool do_bump,
333                            bool do_simplify,
334                            bool bump_in_object_space)
335 {
336         /* before compiling, the shader graph may undergo a number of modifications.
337          * currently we set default geometry shader inputs, and create automatic bump
338          * from displacement. a graph can be finalized only once, and should not be
339          * modified afterwards. */
340
341         if(!finalized) {
342                 simplify(scene);
343
344                 if(do_bump)
345                         bump_from_displacement(bump_in_object_space);
346
347                 ShaderInput *surface_in = output()->input("Surface");
348                 ShaderInput *volume_in = output()->input("Volume");
349
350                 /* todo: make this work when surface and volume closures are tangled up */
351
352                 if(surface_in->link)
353                         transform_multi_closure(surface_in->link->parent, NULL, false);
354                 if(volume_in->link)
355                         transform_multi_closure(volume_in->link->parent, NULL, true);
356
357                 finalized = true;
358         }
359         else if(do_simplify) {
360                 simplify_settings(scene);
361         }
362 }
363
364 void ShaderGraph::find_dependencies(ShaderNodeSet& dependencies, ShaderInput *input)
365 {
366         /* find all nodes that this input depends on directly and indirectly */
367         ShaderNode *node = (input->link)? input->link->parent: NULL;
368
369         if(node != NULL && dependencies.find(node) == dependencies.end()) {
370                 foreach(ShaderInput *in, node->inputs)
371                         find_dependencies(dependencies, in);
372
373                 dependencies.insert(node);
374         }
375 }
376
377 void ShaderGraph::clear_nodes()
378 {
379         foreach(ShaderNode *node, nodes) {
380                 delete node;
381         }
382         nodes.clear();
383 }
384
385 void ShaderGraph::copy_nodes(ShaderNodeSet& nodes, ShaderNodeMap& nnodemap)
386 {
387         /* copy a set of nodes, and the links between them. the assumption is
388          * made that all nodes that inputs are linked to are in the set too. */
389
390         /* copy nodes */
391         foreach(ShaderNode *node, nodes) {
392                 ShaderNode *nnode = node->clone();
393                 nnodemap[node] = nnode;
394
395                 /* create new inputs and outputs to recreate links and ensure
396                  * that we still point to valid SocketType if the NodeType
397                  * changed in cloning, as it does for OSL nodes */
398                 nnode->inputs.clear();
399                 nnode->outputs.clear();
400                 nnode->create_inputs_outputs(nnode->type);
401         }
402
403         /* recreate links */
404         foreach(ShaderNode *node, nodes) {
405                 foreach(ShaderInput *input, node->inputs) {
406                         if(input->link) {
407                                 /* find new input and output */
408                                 ShaderNode *nfrom = nnodemap[input->link->parent];
409                                 ShaderNode *nto = nnodemap[input->parent];
410                                 ShaderOutput *noutput = nfrom->output(input->link->name());
411                                 ShaderInput *ninput = nto->input(input->name());
412
413                                 /* connect */
414                                 connect(noutput, ninput);
415                         }
416                 }
417         }
418 }
419
420 /* Graph simplification */
421 /* ******************** */
422
423 /* Remove proxy nodes.
424  *
425  * These only exists temporarily when exporting groups, and we must remove them
426  * early so that node->attributes() and default links do not see them.
427  */
428 void ShaderGraph::remove_proxy_nodes()
429 {
430         vector<bool> removed(num_node_ids, false);
431         bool any_node_removed = false;
432
433         foreach(ShaderNode *node, nodes) {
434                 if(node->special_type == SHADER_SPECIAL_TYPE_PROXY) {
435                         ConvertNode *proxy = static_cast<ConvertNode*>(node);
436                         ShaderInput *input = proxy->inputs[0];
437                         ShaderOutput *output = proxy->outputs[0];
438
439                         /* bypass the proxy node */
440                         if(input->link) {
441                                 relink(proxy, output, input->link);
442                         }
443                         else {
444                                 /* Copy because disconnect modifies this list */
445                                 vector<ShaderInput*> links(output->links);
446
447                                 foreach(ShaderInput *to, links) {
448                                         /* remove any autoconvert nodes too if they lead to
449                                          * sockets with an automatically set default value */
450                                         ShaderNode *tonode = to->parent;
451
452                                         if(tonode->special_type == SHADER_SPECIAL_TYPE_AUTOCONVERT) {
453                                                 bool all_links_removed = true;
454                                                 vector<ShaderInput*> links = tonode->outputs[0]->links;
455
456                                                 foreach(ShaderInput *autoin, links) {
457                                                         if(autoin->flags() & SocketType::DEFAULT_LINK_MASK)
458                                                                 disconnect(autoin);
459                                                         else
460                                                                 all_links_removed = false;
461                                                 }
462
463                                                 if(all_links_removed)
464                                                         removed[tonode->id] = true;
465                                         }
466
467                                         disconnect(to);
468
469                                         /* transfer the default input value to the target socket */
470                                         tonode->copy_value(to->socket_type, *proxy, input->socket_type);
471                                 }
472                         }
473
474                         removed[proxy->id] = true;
475                         any_node_removed = true;
476                 }
477         }
478
479         /* remove nodes */
480         if(any_node_removed) {
481                 list<ShaderNode*> newnodes;
482
483                 foreach(ShaderNode *node, nodes) {
484                         if(!removed[node->id])
485                                 newnodes.push_back(node);
486                         else
487                                 delete node;
488                 }
489
490                 nodes = newnodes;
491         }
492 }
493
494 /* Constant folding.
495  *
496  * Try to constant fold some nodes, and pipe result directly to
497  * the input socket of connected nodes.
498  */
499 void ShaderGraph::constant_fold(Scene *scene)
500 {
501         ShaderNodeSet done, scheduled;
502         queue<ShaderNode*> traverse_queue;
503
504         bool has_displacement = (output()->input("Displacement")->link != NULL);
505
506         /* Schedule nodes which doesn't have any dependencies. */
507         foreach(ShaderNode *node, nodes) {
508                 if(!check_node_inputs_has_links(node)) {
509                         traverse_queue.push(node);
510                         scheduled.insert(node);
511                 }
512         }
513
514         while(!traverse_queue.empty()) {
515                 ShaderNode *node = traverse_queue.front();
516                 traverse_queue.pop();
517                 done.insert(node);
518                 foreach(ShaderOutput *output, node->outputs) {
519                         if(output->links.size() == 0) {
520                                 continue;
521                         }
522                         /* Schedule node which was depending on the value,
523                          * when possible. Do it before disconnect.
524                          */
525                         foreach(ShaderInput *input, output->links) {
526                                 if(scheduled.find(input->parent) != scheduled.end()) {
527                                         /* Node might not be optimized yet but scheduled already
528                                          * by other dependencies. No need to re-schedule it.
529                                          */
530                                         continue;
531                                 }
532                                 /* Schedule node if its inputs are fully done. */
533                                 if(check_node_inputs_traversed(input->parent, done)) {
534                                         traverse_queue.push(input->parent);
535                                         scheduled.insert(input->parent);
536                                 }
537                         }
538                         /* Optimize current node. */
539                         ConstantFolder folder(this, node, output, scene);
540                         node->constant_fold(folder);
541                 }
542         }
543
544         /* Folding might have removed all nodes connected to the displacement output
545          * even tho there is displacement to be applied, so add in a value node if
546          * that happens to ensure there is still a valid graph for displacement.
547          */
548         if(has_displacement && !output()->input("Displacement")->link) {
549                 ColorNode *value = (ColorNode*)add(new ColorNode());
550                 value->value = output()->displacement;
551
552                 connect(value->output("Color"), output()->input("Displacement"));
553         }
554 }
555
556 /* Simplification. */
557 void ShaderGraph::simplify_settings(Scene *scene)
558 {
559         foreach(ShaderNode *node, nodes) {
560                 node->simplify_settings(scene);
561         }
562 }
563
564 /* Deduplicate nodes with same settings. */
565 void ShaderGraph::deduplicate_nodes()
566 {
567         /* NOTES:
568          * - Deduplication happens for nodes which has same exact settings and same
569          *   exact input links configuration (either connected to same output or has
570          *   the same exact default value).
571          * - Deduplication happens in the bottom-top manner, so we know for fact that
572          *   all traversed nodes are either can not be deduplicated at all or were
573          *   already deduplicated.
574          */
575
576         ShaderNodeSet scheduled, done;
577         map<ustring, ShaderNodeSet> candidates;
578         queue<ShaderNode*> traverse_queue;
579         int num_deduplicated = 0;
580
581         /* Schedule nodes which doesn't have any dependencies. */
582         foreach(ShaderNode *node, nodes) {
583                 if(!check_node_inputs_has_links(node)) {
584                         traverse_queue.push(node);
585                         scheduled.insert(node);
586                 }
587         }
588
589         while(!traverse_queue.empty()) {
590                 ShaderNode *node = traverse_queue.front();
591                 traverse_queue.pop();
592                 done.insert(node);
593                 /* Schedule the nodes which were depending on the current node. */
594                 bool has_output_links = false;
595                 foreach(ShaderOutput *output, node->outputs) {
596                         foreach(ShaderInput *input, output->links) {
597                                 has_output_links = true;
598                                 if(scheduled.find(input->parent) != scheduled.end()) {
599                                         /* Node might not be optimized yet but scheduled already
600                                          * by other dependencies. No need to re-schedule it.
601                                          */
602                                         continue;
603                                 }
604                                 /* Schedule node if its inputs are fully done. */
605                                 if(check_node_inputs_traversed(input->parent, done)) {
606                                         traverse_queue.push(input->parent);
607                                         scheduled.insert(input->parent);
608                                 }
609                         }
610                 }
611                 /* Only need to care about nodes that are actually used */
612                 if(!has_output_links) {
613                         continue;
614                 }
615                 /* Try to merge this node with another one. */
616                 ShaderNode *merge_with = NULL;
617                 foreach(ShaderNode *other_node, candidates[node->type->name]) {
618                         if(node != other_node && node->equals(*other_node)) {
619                                 merge_with = other_node;
620                                 break;
621                         }
622                 }
623                 /* If found an equivalent, merge; otherwise keep node for later merges */
624                 if(merge_with != NULL) {
625                         for(int i = 0; i < node->outputs.size(); ++i) {
626                                 relink(node, node->outputs[i], merge_with->outputs[i]);
627                         }
628                         num_deduplicated++;
629                 }
630                 else {
631                         candidates[node->type->name].insert(node);
632                 }
633         }
634
635         if(num_deduplicated > 0) {
636                 VLOG(1) << "Deduplicated " << num_deduplicated << " nodes.";
637         }
638 }
639
640 /* Check whether volume output has meaningful nodes, otherwise
641  * disconnect the output.
642  */
643 void ShaderGraph::verify_volume_output()
644 {
645         /* Check whether we can optimize the whole volume graph out. */
646         ShaderInput *volume_in = output()->input("Volume");
647         if(volume_in->link == NULL) {
648                 return;
649         }
650         bool has_valid_volume = false;
651         ShaderNodeSet scheduled;
652         queue<ShaderNode*> traverse_queue;
653         /* Schedule volume output. */
654         traverse_queue.push(volume_in->link->parent);
655         scheduled.insert(volume_in->link->parent);
656         /* Traverse down the tree. */
657         while(!traverse_queue.empty()) {
658                 ShaderNode *node = traverse_queue.front();
659                 traverse_queue.pop();
660                 /* Node is fully valid for volume, can't optimize anything out. */
661                 if(node->has_volume_support()) {
662                         has_valid_volume = true;
663                         break;
664                 }
665                 foreach(ShaderInput *input, node->inputs) {
666                         if(input->link == NULL) {
667                                 continue;
668                         }
669                         if(scheduled.find(input->link->parent) != scheduled.end()) {
670                                 continue;
671                         }
672                         traverse_queue.push(input->link->parent);
673                         scheduled.insert(input->link->parent);
674                 }
675         }
676         if(!has_valid_volume) {
677                 VLOG(1) << "Disconnect meaningless volume output.";
678                 disconnect(volume_in->link);
679         }
680 }
681
682 void ShaderGraph::break_cycles(ShaderNode *node, vector<bool>& visited, vector<bool>& on_stack)
683 {
684         visited[node->id] = true;
685         on_stack[node->id] = true;
686
687         foreach(ShaderInput *input, node->inputs) {
688                 if(input->link) {
689                         ShaderNode *depnode = input->link->parent;
690
691                         if(on_stack[depnode->id]) {
692                                 /* break cycle */
693                                 disconnect(input);
694                                 fprintf(stderr, "Cycles shader graph: detected cycle in graph, connection removed.\n");
695                         }
696                         else if(!visited[depnode->id]) {
697                                 /* visit dependencies */
698                                 break_cycles(depnode, visited, on_stack);
699                         }
700                 }
701         }
702
703         on_stack[node->id] = false;
704 }
705
706 void ShaderGraph::compute_displacement_hash()
707 {
708         /* Compute hash of all nodes linked to displacement, to detect if we need
709          * to recompute displacement when shader nodes change. */
710         ShaderInput *displacement_in = output()->input("Displacement");
711
712         if(!displacement_in->link) {
713                 displacement_hash = "";
714                 return;
715         }
716
717         ShaderNodeSet nodes_displace;
718         find_dependencies(nodes_displace, displacement_in);
719
720         MD5Hash md5;
721         foreach(ShaderNode *node, nodes_displace) {
722                 node->hash(md5);
723                 foreach(ShaderInput *input, node->inputs) {
724                         int link_id = (input->link) ? input->link->parent->id : 0;
725                         md5.append((uint8_t*)&link_id, sizeof(link_id));
726                 }
727         }
728
729         displacement_hash = md5.get_hex();
730 }
731
732 void ShaderGraph::clean(Scene *scene)
733 {
734         /* Graph simplification */
735
736         /* NOTE: Remove proxy nodes was already done. */
737         constant_fold(scene);
738         simplify_settings(scene);
739         deduplicate_nodes();
740         verify_volume_output();
741
742         /* we do two things here: find cycles and break them, and remove unused
743          * nodes that don't feed into the output. how cycles are broken is
744          * undefined, they are invalid input, the important thing is to not crash */
745
746         vector<bool> visited(num_node_ids, false);
747         vector<bool> on_stack(num_node_ids, false);
748         
749         /* break cycles */
750         break_cycles(output(), visited, on_stack);
751
752         /* disconnect unused nodes */
753         foreach(ShaderNode *node, nodes) {
754                 if(!visited[node->id]) {
755                         foreach(ShaderInput *to, node->inputs) {
756                                 ShaderOutput *from = to->link;
757
758                                 if(from) {
759                                         to->link = NULL;
760                                         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
761                                 }
762                         }
763                 }
764         }
765
766         /* remove unused nodes */
767         list<ShaderNode*> newnodes;
768
769         foreach(ShaderNode *node, nodes) {
770                 if(visited[node->id])
771                         newnodes.push_back(node);
772                 else
773                         delete node;
774         }
775
776         nodes = newnodes;
777 }
778
779 void ShaderGraph::default_inputs(bool do_osl)
780 {
781         /* nodes can specify default texture coordinates, for now we give
782          * everything the position by default, except for the sky texture */
783
784         ShaderNode *geom = NULL;
785         ShaderNode *texco = NULL;
786
787         foreach(ShaderNode *node, nodes) {
788                 foreach(ShaderInput *input, node->inputs) {
789                         if(!input->link && (!(input->flags() & SocketType::OSL_INTERNAL) || do_osl)) {
790                                 if(input->flags() & SocketType::LINK_TEXTURE_GENERATED) {
791                                         if(!texco)
792                                                 texco = new TextureCoordinateNode();
793
794                                         connect(texco->output("Generated"), input);
795                                 }
796                                 if(input->flags() & SocketType::LINK_TEXTURE_NORMAL) {
797                                         if(!texco)
798                                                 texco = new TextureCoordinateNode();
799
800                                         connect(texco->output("Normal"), input);
801                                 }
802                                 else if(input->flags() & SocketType::LINK_TEXTURE_UV) {
803                                         if(!texco)
804                                                 texco = new TextureCoordinateNode();
805
806                                         connect(texco->output("UV"), input);
807                                 }
808                                 else if(input->flags() & SocketType::LINK_INCOMING) {
809                                         if(!geom)
810                                                 geom = new GeometryNode();
811
812                                         connect(geom->output("Incoming"), input);
813                                 }
814                                 else if(input->flags() & SocketType::LINK_NORMAL) {
815                                         if(!geom)
816                                                 geom = new GeometryNode();
817
818                                         connect(geom->output("Normal"), input);
819                                 }
820                                 else if(input->flags() & SocketType::LINK_POSITION) {
821                                         if(!geom)
822                                                 geom = new GeometryNode();
823
824                                         connect(geom->output("Position"), input);
825                                 }
826                                 else if(input->flags() & SocketType::LINK_TANGENT) {
827                                         if(!geom)
828                                                 geom = new GeometryNode();
829
830                                         connect(geom->output("Tangent"), input);
831                                 }
832                         }
833                 }
834         }
835
836         if(geom)
837                 add(geom);
838         if(texco)
839                 add(texco);
840 }
841
842 void ShaderGraph::refine_bump_nodes()
843 {
844         /* we transverse the node graph looking for bump nodes, when we find them,
845          * like in bump_from_displacement(), we copy the sub-graph defined from "bump"
846          * input to the inputs "center","dx" and "dy" What is in "bump" input is moved
847          * to "center" input. */
848
849         foreach(ShaderNode *node, nodes) {
850                 if(node->special_type == SHADER_SPECIAL_TYPE_BUMP && node->input("Height")->link) {
851                         ShaderInput *bump_input = node->input("Height");
852                         ShaderNodeSet nodes_bump;
853
854                         /* make 2 extra copies of the subgraph defined in Bump input */
855                         ShaderNodeMap nodes_dx;
856                         ShaderNodeMap nodes_dy;
857
858                         /* find dependencies for the given input */
859                         find_dependencies(nodes_bump, bump_input);
860
861                         copy_nodes(nodes_bump, nodes_dx);
862                         copy_nodes(nodes_bump, nodes_dy);
863         
864                         /* mark nodes to indicate they are use for bump computation, so
865                            that any texture coordinates are shifted by dx/dy when sampling */
866                         foreach(ShaderNode *node, nodes_bump)
867                                 node->bump = SHADER_BUMP_CENTER;
868                         foreach(NodePair& pair, nodes_dx)
869                                 pair.second->bump = SHADER_BUMP_DX;
870                         foreach(NodePair& pair, nodes_dy)
871                                 pair.second->bump = SHADER_BUMP_DY;
872
873                         ShaderOutput *out = bump_input->link;
874                         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name());
875                         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name());
876
877                         connect(out_dx, node->input("SampleX"));
878                         connect(out_dy, node->input("SampleY"));
879                         
880                         /* add generated nodes */
881                         foreach(NodePair& pair, nodes_dx)
882                                 add(pair.second);
883                         foreach(NodePair& pair, nodes_dy)
884                                 add(pair.second);
885                         
886                         /* connect what is connected is bump to samplecenter input*/
887                         connect(out , node->input("SampleCenter"));
888
889                         /* bump input is just for connectivity purpose for the graph input,
890                          * we re-connected this input to samplecenter, so lets disconnect it
891                          * from bump input */
892                         disconnect(bump_input);
893                 }
894         }
895 }
896
897 void ShaderGraph::bump_from_displacement(bool use_object_space)
898 {
899         /* generate bump mapping automatically from displacement. bump mapping is
900          * done using a 3-tap filter, computing the displacement at the center,
901          * and two other positions shifted by ray differentials.
902          *
903          * since the input to displacement is a node graph, we need to ensure that
904          * all texture coordinates use are shift by the ray differentials. for this
905          * reason we make 3 copies of the node subgraph defining the displacement,
906          * with each different geometry and texture coordinate nodes that generate
907          * different shifted coordinates.
908          *
909          * these 3 displacement values are then fed into the bump node, which will
910          * output the perturbed normal. */
911
912         ShaderInput *displacement_in = output()->input("Displacement");
913
914         if(!displacement_in->link)
915                 return;
916
917         /* find dependencies for the given input */
918         ShaderNodeSet nodes_displace;
919         find_dependencies(nodes_displace, displacement_in);
920
921         /* copy nodes for 3 bump samples */
922         ShaderNodeMap nodes_center;
923         ShaderNodeMap nodes_dx;
924         ShaderNodeMap nodes_dy;
925
926         copy_nodes(nodes_displace, nodes_center);
927         copy_nodes(nodes_displace, nodes_dx);
928         copy_nodes(nodes_displace, nodes_dy);
929
930         /* mark nodes to indicate they are use for bump computation, so
931          * that any texture coordinates are shifted by dx/dy when sampling */
932         foreach(NodePair& pair, nodes_center)
933                 pair.second->bump = SHADER_BUMP_CENTER;
934         foreach(NodePair& pair, nodes_dx)
935                 pair.second->bump = SHADER_BUMP_DX;
936         foreach(NodePair& pair, nodes_dy)
937                 pair.second->bump = SHADER_BUMP_DY;
938
939         /* add set normal node and connect the bump normal ouput to the set normal
940          * output, so it can finally set the shader normal, note we are only doing
941          * this for bump from displacement, this will be the only bump allowed to
942          * overwrite the shader normal */
943         ShaderNode *set_normal = add(new SetNormalNode());
944         
945         /* add bump node and connect copied graphs to it */
946         BumpNode *bump = (BumpNode*)add(new BumpNode());
947         bump->use_object_space = use_object_space;
948         bump->distance = 1.0f;
949
950         ShaderOutput *out = displacement_in->link;
951         ShaderOutput *out_center = nodes_center[out->parent]->output(out->name());
952         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name());
953         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name());
954
955         /* convert displacement vector to height */
956         VectorMathNode *dot_center = (VectorMathNode*)add(new VectorMathNode());
957         VectorMathNode *dot_dx = (VectorMathNode*)add(new VectorMathNode());
958         VectorMathNode *dot_dy = (VectorMathNode*)add(new VectorMathNode());
959
960         dot_center->type = NODE_VECTOR_MATH_DOT_PRODUCT;
961         dot_dx->type = NODE_VECTOR_MATH_DOT_PRODUCT;
962         dot_dy->type = NODE_VECTOR_MATH_DOT_PRODUCT;
963
964         GeometryNode *geom = (GeometryNode*)add(new GeometryNode());
965         connect(geom->output("Normal"), dot_center->input("Vector2"));
966         connect(geom->output("Normal"), dot_dx->input("Vector2"));
967         connect(geom->output("Normal"), dot_dy->input("Vector2"));
968
969         connect(out_center, dot_center->input("Vector1"));
970         connect(out_dx, dot_dx->input("Vector1"));
971         connect(out_dy, dot_dy->input("Vector1"));
972
973         connect(dot_center->output("Value"), bump->input("SampleCenter"));
974         connect(dot_dx->output("Value"), bump->input("SampleX"));
975         connect(dot_dy->output("Value"), bump->input("SampleY"));
976         
977         /* connect the bump out to the set normal in: */
978         connect(bump->output("Normal"), set_normal->input("Direction"));
979
980         /* connect to output node */
981         connect(set_normal->output("Normal"), output()->input("Normal"));
982
983         /* finally, add the copied nodes to the graph. we can't do this earlier
984          * because we would create dependency cycles in the above loop */
985         foreach(NodePair& pair, nodes_center)
986                 add(pair.second);
987         foreach(NodePair& pair, nodes_dx)
988                 add(pair.second);
989         foreach(NodePair& pair, nodes_dy)
990                 add(pair.second);
991 }
992
993 void ShaderGraph::transform_multi_closure(ShaderNode *node, ShaderOutput *weight_out, bool volume)
994 {
995         /* for SVM in multi closure mode, this transforms the shader mix/add part of
996          * the graph into nodes that feed weights into closure nodes. this is too
997          * avoid building a closure tree and then flattening it, and instead write it
998          * directly to an array */
999         
1000         if(node->special_type == SHADER_SPECIAL_TYPE_COMBINE_CLOSURE) {
1001                 ShaderInput *fin = node->input("Fac");
1002                 ShaderInput *cl1in = node->input("Closure1");
1003                 ShaderInput *cl2in = node->input("Closure2");
1004                 ShaderOutput *weight1_out, *weight2_out;
1005
1006                 if(fin) {
1007                         /* mix closure: add node to mix closure weights */
1008                         MixClosureWeightNode *mix_node = new MixClosureWeightNode();
1009                         add(mix_node);
1010                         ShaderInput *fac_in = mix_node->input("Fac"); 
1011                         ShaderInput *weight_in = mix_node->input("Weight"); 
1012
1013                         if(fin->link)
1014                                 connect(fin->link, fac_in);
1015                         else
1016                                 mix_node->fac = node->get_float(fin->socket_type);
1017
1018                         if(weight_out)
1019                                 connect(weight_out, weight_in);
1020
1021                         weight1_out = mix_node->output("Weight1");
1022                         weight2_out = mix_node->output("Weight2");
1023                 }
1024                 else {
1025                         /* add closure: just pass on any weights */
1026                         weight1_out = weight_out;
1027                         weight2_out = weight_out;
1028                 }
1029
1030                 if(cl1in->link)
1031                         transform_multi_closure(cl1in->link->parent, weight1_out, volume);
1032                 if(cl2in->link)
1033                         transform_multi_closure(cl2in->link->parent, weight2_out, volume);
1034         }
1035         else {
1036                 ShaderInput *weight_in = node->input((volume)? "VolumeMixWeight": "SurfaceMixWeight");
1037
1038                 /* not a closure node? */
1039                 if(!weight_in)
1040                         return;
1041
1042                 /* already has a weight connected to it? add weights */
1043                 float weight_value = node->get_float(weight_in->socket_type);
1044                 if(weight_in->link || weight_value != 0.0f) {
1045                         MathNode *math_node = new MathNode();
1046                         add(math_node);
1047
1048                         if(weight_in->link)
1049                                 connect(weight_in->link, math_node->input("Value1"));
1050                         else
1051                                 math_node->value1 = weight_value;
1052
1053                         if(weight_out)
1054                                 connect(weight_out, math_node->input("Value2"));
1055                         else
1056                                 math_node->value2 = 1.0f;
1057
1058                         weight_out = math_node->output("Value");
1059                         if(weight_in->link)
1060                                 disconnect(weight_in);
1061                 }
1062
1063                 /* connected to closure mix weight */
1064                 if(weight_out)
1065                         connect(weight_out, weight_in);
1066                 else
1067                         node->set(weight_in->socket_type, weight_value + 1.0f);
1068         }
1069 }
1070
1071 int ShaderGraph::get_num_closures()
1072 {
1073         int num_closures = 0;
1074         foreach(ShaderNode *node, nodes) {
1075                 ClosureType closure_type = node->get_closure_type();
1076                 if(closure_type == CLOSURE_NONE_ID) {
1077                         continue;
1078                 }
1079                 else if(CLOSURE_IS_BSSRDF(closure_type)) {
1080                         num_closures += 3;
1081                 }
1082                 else if(CLOSURE_IS_GLASS(closure_type)) {
1083                         num_closures += 2;
1084                 }
1085                 else if(CLOSURE_IS_BSDF_MULTISCATTER(closure_type)) {
1086                         num_closures += 2;
1087                 }
1088                 else if(CLOSURE_IS_PRINCIPLED(closure_type)) {
1089                         num_closures += 8;
1090                 }
1091                 else if(CLOSURE_IS_VOLUME(closure_type)) {
1092                         num_closures += VOLUME_STACK_SIZE;
1093                 }
1094                 else {
1095                         ++num_closures;
1096                 }
1097         }
1098         return num_closures;
1099 }
1100
1101 void ShaderGraph::dump_graph(const char *filename)
1102 {
1103         FILE *fd = fopen(filename, "w");
1104
1105         if(fd == NULL) {
1106                 printf("Error opening file for dumping the graph: %s\n", filename);
1107                 return;
1108         }
1109
1110         fprintf(fd, "digraph shader_graph {\n");
1111         fprintf(fd, "ranksep=1.5\n");
1112         fprintf(fd, "rankdir=LR\n");
1113         fprintf(fd, "splines=false\n");
1114
1115         foreach(ShaderNode *node, nodes) {
1116                 fprintf(fd, "// NODE: %p\n", node);
1117                 fprintf(fd, "\"%p\" [shape=record,label=\"{", node);
1118                 if(node->inputs.size()) {
1119                         fprintf(fd, "{");
1120                         foreach(ShaderInput *socket, node->inputs) {
1121                                 if(socket != node->inputs[0]) {
1122                                         fprintf(fd, "|");
1123                                 }
1124                                 fprintf(fd, "<IN_%p>%s", socket, socket->name().c_str());
1125                         }
1126                         fprintf(fd, "}|");
1127                 }
1128                 fprintf(fd, "%s", node->name.c_str());
1129                 if(node->bump == SHADER_BUMP_CENTER) {
1130                         fprintf(fd, " (bump:center)");
1131                 }
1132                 else if(node->bump == SHADER_BUMP_DX) {
1133                         fprintf(fd, " (bump:dx)");
1134                 }
1135                 else if(node->bump == SHADER_BUMP_DY) {
1136                         fprintf(fd, " (bump:dy)");
1137                 }
1138                 if(node->outputs.size()) {
1139                         fprintf(fd, "|{");
1140                         foreach(ShaderOutput *socket, node->outputs) {
1141                                 if(socket != node->outputs[0]) {
1142                                         fprintf(fd, "|");
1143                                 }
1144                                 fprintf(fd, "<OUT_%p>%s", socket, socket->name().c_str());
1145                         }
1146                         fprintf(fd, "}");
1147                 }
1148                 fprintf(fd, "}\"]");
1149         }
1150
1151         foreach(ShaderNode *node, nodes) {
1152                 foreach(ShaderOutput *output, node->outputs) {
1153                         foreach(ShaderInput *input, output->links) {
1154                                 fprintf(fd,
1155                                         "// CONNECTION: OUT_%p->IN_%p (%s:%s)\n",
1156                                         output,
1157                                         input,
1158                                         output->name().c_str(), input->name().c_str());
1159                                 fprintf(fd,
1160                                         "\"%p\":\"OUT_%p\":e -> \"%p\":\"IN_%p\":w [label=\"\"]\n",
1161                                         output->parent,
1162                                         output,
1163                                         input->parent,
1164                                         input);
1165                         }
1166                 }
1167         }
1168
1169         fprintf(fd, "}\n");
1170         fclose(fd);
1171 }
1172
1173 CCL_NAMESPACE_END
1174