Fix T63415: no Cycles displacement update when updating OSL code
[blender.git] / intern / cycles / render / graph.cpp
1 /*
2  * Copyright 2011-2016 Blender Foundation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "render/attribute.h"
18 #include "render/graph.h"
19 #include "render/nodes.h"
20 #include "render/scene.h"
21 #include "render/shader.h"
22 #include "render/constant_fold.h"
23
24 #include "util/util_algorithm.h"
25 #include "util/util_foreach.h"
26 #include "util/util_logging.h"
27 #include "util/util_md5.h"
28 #include "util/util_queue.h"
29
30 CCL_NAMESPACE_BEGIN
31
32 namespace {
33
34 bool check_node_inputs_has_links(const ShaderNode *node)
35 {
36   foreach (const ShaderInput *in, node->inputs) {
37     if (in->link) {
38       return true;
39     }
40   }
41   return false;
42 }
43
44 bool check_node_inputs_traversed(const ShaderNode *node, const ShaderNodeSet &done)
45 {
46   foreach (const ShaderInput *in, node->inputs) {
47     if (in->link) {
48       if (done.find(in->link->parent) == done.end()) {
49         return false;
50       }
51     }
52   }
53   return true;
54 }
55
56 } /* namespace */
57
58 /* Node */
59
60 ShaderNode::ShaderNode(const NodeType *type) : Node(type)
61 {
62   name = type->name;
63   id = -1;
64   bump = SHADER_BUMP_NONE;
65   special_type = SHADER_SPECIAL_TYPE_NONE;
66
67   create_inputs_outputs(type);
68 }
69
70 ShaderNode::~ShaderNode()
71 {
72   foreach (ShaderInput *socket, inputs)
73     delete socket;
74
75   foreach (ShaderOutput *socket, outputs)
76     delete socket;
77 }
78
79 void ShaderNode::create_inputs_outputs(const NodeType *type)
80 {
81   foreach (const SocketType &socket, type->inputs) {
82     if (socket.flags & SocketType::LINKABLE) {
83       inputs.push_back(new ShaderInput(socket, this));
84     }
85   }
86
87   foreach (const SocketType &socket, type->outputs) {
88     outputs.push_back(new ShaderOutput(socket, this));
89   }
90 }
91
92 ShaderInput *ShaderNode::input(const char *name)
93 {
94   foreach (ShaderInput *socket, inputs) {
95     if (socket->name() == name)
96       return socket;
97   }
98
99   return NULL;
100 }
101
102 ShaderOutput *ShaderNode::output(const char *name)
103 {
104   foreach (ShaderOutput *socket, outputs)
105     if (socket->name() == name)
106       return socket;
107
108   return NULL;
109 }
110
111 ShaderInput *ShaderNode::input(ustring name)
112 {
113   foreach (ShaderInput *socket, inputs) {
114     if (socket->name() == name)
115       return socket;
116   }
117
118   return NULL;
119 }
120
121 ShaderOutput *ShaderNode::output(ustring name)
122 {
123   foreach (ShaderOutput *socket, outputs)
124     if (socket->name() == name)
125       return socket;
126
127   return NULL;
128 }
129
130 void ShaderNode::attributes(Shader *shader, AttributeRequestSet *attributes)
131 {
132   foreach (ShaderInput *input, inputs) {
133     if (!input->link) {
134       if (input->flags() & SocketType::LINK_TEXTURE_GENERATED) {
135         if (shader->has_surface)
136           attributes->add(ATTR_STD_GENERATED);
137         if (shader->has_volume)
138           attributes->add(ATTR_STD_GENERATED_TRANSFORM);
139       }
140       else if (input->flags() & SocketType::LINK_TEXTURE_UV) {
141         if (shader->has_surface)
142           attributes->add(ATTR_STD_UV);
143       }
144     }
145   }
146 }
147
148 bool ShaderNode::equals(const ShaderNode &other)
149 {
150   if (type != other.type || bump != other.bump) {
151     return false;
152   }
153
154   assert(inputs.size() == other.inputs.size());
155
156   /* Compare unlinkable sockets */
157   foreach (const SocketType &socket, type->inputs) {
158     if (!(socket.flags & SocketType::LINKABLE)) {
159       if (!Node::equals_value(other, socket)) {
160         return false;
161       }
162     }
163   }
164
165   /* Compare linkable input sockets */
166   for (int i = 0; i < inputs.size(); ++i) {
167     ShaderInput *input_a = inputs[i], *input_b = other.inputs[i];
168     if (input_a->link == NULL && input_b->link == NULL) {
169       /* Unconnected inputs are expected to have the same value. */
170       if (!Node::equals_value(other, input_a->socket_type)) {
171         return false;
172       }
173     }
174     else if (input_a->link != NULL && input_b->link != NULL) {
175       /* Expect links are to come from the same exact socket. */
176       if (input_a->link != input_b->link) {
177         return false;
178       }
179     }
180     else {
181       /* One socket has a link and another has not, inputs can't be
182        * considered equal.
183        */
184       return false;
185     }
186   }
187
188   return true;
189 }
190
191 /* Graph */
192
193 ShaderGraph::ShaderGraph()
194 {
195   finalized = false;
196   simplified = false;
197   num_node_ids = 0;
198   add(new OutputNode());
199 }
200
201 ShaderGraph::~ShaderGraph()
202 {
203   clear_nodes();
204 }
205
206 ShaderNode *ShaderGraph::add(ShaderNode *node)
207 {
208   assert(!finalized);
209   simplified = false;
210
211   node->id = num_node_ids++;
212   nodes.push_back(node);
213   return node;
214 }
215
216 OutputNode *ShaderGraph::output()
217 {
218   return (OutputNode *)nodes.front();
219 }
220
221 void ShaderGraph::connect(ShaderOutput *from, ShaderInput *to)
222 {
223   assert(!finalized);
224   assert(from && to);
225
226   if (to->link) {
227     fprintf(stderr, "Cycles shader graph connect: input already connected.\n");
228     return;
229   }
230
231   if (from->type() != to->type()) {
232     /* can't do automatic conversion from closure */
233     if (from->type() == SocketType::CLOSURE) {
234       fprintf(stderr,
235               "Cycles shader graph connect: can only connect closure to closure "
236               "(%s.%s to %s.%s).\n",
237               from->parent->name.c_str(),
238               from->name().c_str(),
239               to->parent->name.c_str(),
240               to->name().c_str());
241       return;
242     }
243
244     /* add automatic conversion node in case of type mismatch */
245     ShaderNode *convert;
246     ShaderInput *convert_in;
247
248     if (to->type() == SocketType::CLOSURE) {
249       EmissionNode *emission = new EmissionNode();
250       emission->color = make_float3(1.0f, 1.0f, 1.0f);
251       emission->strength = 1.0f;
252       convert = add(emission);
253       /* Connect float inputs to Strength to save an additional Falue->Color conversion. */
254       if (from->type() == SocketType::FLOAT) {
255         convert_in = convert->input("Strength");
256       }
257       else {
258         convert_in = convert->input("Color");
259       }
260     }
261     else {
262       convert = add(new ConvertNode(from->type(), to->type(), true));
263       convert_in = convert->inputs[0];
264     }
265
266     connect(from, convert_in);
267     connect(convert->outputs[0], to);
268   }
269   else {
270     /* types match, just connect */
271     to->link = from;
272     from->links.push_back(to);
273   }
274 }
275
276 void ShaderGraph::disconnect(ShaderOutput *from)
277 {
278   assert(!finalized);
279   simplified = false;
280
281   foreach (ShaderInput *sock, from->links) {
282     sock->link = NULL;
283   }
284
285   from->links.clear();
286 }
287
288 void ShaderGraph::disconnect(ShaderInput *to)
289 {
290   assert(!finalized);
291   assert(to->link);
292   simplified = false;
293
294   ShaderOutput *from = to->link;
295
296   to->link = NULL;
297   from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
298 }
299
300 void ShaderGraph::relink(ShaderNode *node, ShaderOutput *from, ShaderOutput *to)
301 {
302   simplified = false;
303
304   /* Copy because disconnect modifies this list */
305   vector<ShaderInput *> outputs = from->links;
306
307   /* Bypass node by moving all links from "from" to "to" */
308   foreach (ShaderInput *sock, node->inputs) {
309     if (sock->link)
310       disconnect(sock);
311   }
312
313   foreach (ShaderInput *sock, outputs) {
314     disconnect(sock);
315     if (to)
316       connect(to, sock);
317   }
318 }
319
320 void ShaderGraph::simplify(Scene *scene)
321 {
322   if (!simplified) {
323     default_inputs(scene->shader_manager->use_osl());
324     clean(scene);
325     refine_bump_nodes();
326
327     simplified = true;
328   }
329 }
330
331 void ShaderGraph::finalize(Scene *scene, bool do_bump, bool do_simplify, bool bump_in_object_space)
332 {
333   /* before compiling, the shader graph may undergo a number of modifications.
334    * currently we set default geometry shader inputs, and create automatic bump
335    * from displacement. a graph can be finalized only once, and should not be
336    * modified afterwards. */
337
338   if (!finalized) {
339     simplify(scene);
340
341     if (do_bump)
342       bump_from_displacement(bump_in_object_space);
343
344     ShaderInput *surface_in = output()->input("Surface");
345     ShaderInput *volume_in = output()->input("Volume");
346
347     /* todo: make this work when surface and volume closures are tangled up */
348
349     if (surface_in->link)
350       transform_multi_closure(surface_in->link->parent, NULL, false);
351     if (volume_in->link)
352       transform_multi_closure(volume_in->link->parent, NULL, true);
353
354     finalized = true;
355   }
356   else if (do_simplify) {
357     simplify_settings(scene);
358   }
359 }
360
361 void ShaderGraph::find_dependencies(ShaderNodeSet &dependencies, ShaderInput *input)
362 {
363   /* find all nodes that this input depends on directly and indirectly */
364   ShaderNode *node = (input->link) ? input->link->parent : NULL;
365
366   if (node != NULL && dependencies.find(node) == dependencies.end()) {
367     foreach (ShaderInput *in, node->inputs)
368       find_dependencies(dependencies, in);
369
370     dependencies.insert(node);
371   }
372 }
373
374 void ShaderGraph::clear_nodes()
375 {
376   foreach (ShaderNode *node, nodes) {
377     delete node;
378   }
379   nodes.clear();
380 }
381
382 void ShaderGraph::copy_nodes(ShaderNodeSet &nodes, ShaderNodeMap &nnodemap)
383 {
384   /* copy a set of nodes, and the links between them. the assumption is
385    * made that all nodes that inputs are linked to are in the set too. */
386
387   /* copy nodes */
388   foreach (ShaderNode *node, nodes) {
389     ShaderNode *nnode = node->clone();
390     nnodemap[node] = nnode;
391
392     /* create new inputs and outputs to recreate links and ensure
393      * that we still point to valid SocketType if the NodeType
394      * changed in cloning, as it does for OSL nodes */
395     nnode->inputs.clear();
396     nnode->outputs.clear();
397     nnode->create_inputs_outputs(nnode->type);
398   }
399
400   /* recreate links */
401   foreach (ShaderNode *node, nodes) {
402     foreach (ShaderInput *input, node->inputs) {
403       if (input->link) {
404         /* find new input and output */
405         ShaderNode *nfrom = nnodemap[input->link->parent];
406         ShaderNode *nto = nnodemap[input->parent];
407         ShaderOutput *noutput = nfrom->output(input->link->name());
408         ShaderInput *ninput = nto->input(input->name());
409
410         /* connect */
411         connect(noutput, ninput);
412       }
413     }
414   }
415 }
416
417 /* Graph simplification */
418 /* ******************** */
419
420 /* Remove proxy nodes.
421  *
422  * These only exists temporarily when exporting groups, and we must remove them
423  * early so that node->attributes() and default links do not see them.
424  */
425 void ShaderGraph::remove_proxy_nodes()
426 {
427   vector<bool> removed(num_node_ids, false);
428   bool any_node_removed = false;
429
430   foreach (ShaderNode *node, nodes) {
431     if (node->special_type == SHADER_SPECIAL_TYPE_PROXY) {
432       ConvertNode *proxy = static_cast<ConvertNode *>(node);
433       ShaderInput *input = proxy->inputs[0];
434       ShaderOutput *output = proxy->outputs[0];
435
436       /* bypass the proxy node */
437       if (input->link) {
438         relink(proxy, output, input->link);
439       }
440       else {
441         /* Copy because disconnect modifies this list */
442         vector<ShaderInput *> links(output->links);
443
444         foreach (ShaderInput *to, links) {
445           /* remove any autoconvert nodes too if they lead to
446            * sockets with an automatically set default value */
447           ShaderNode *tonode = to->parent;
448
449           if (tonode->special_type == SHADER_SPECIAL_TYPE_AUTOCONVERT) {
450             bool all_links_removed = true;
451             vector<ShaderInput *> links = tonode->outputs[0]->links;
452
453             foreach (ShaderInput *autoin, links) {
454               if (autoin->flags() & SocketType::DEFAULT_LINK_MASK)
455                 disconnect(autoin);
456               else
457                 all_links_removed = false;
458             }
459
460             if (all_links_removed)
461               removed[tonode->id] = true;
462           }
463
464           disconnect(to);
465
466           /* transfer the default input value to the target socket */
467           tonode->copy_value(to->socket_type, *proxy, input->socket_type);
468         }
469       }
470
471       removed[proxy->id] = true;
472       any_node_removed = true;
473     }
474   }
475
476   /* remove nodes */
477   if (any_node_removed) {
478     list<ShaderNode *> newnodes;
479
480     foreach (ShaderNode *node, nodes) {
481       if (!removed[node->id])
482         newnodes.push_back(node);
483       else
484         delete node;
485     }
486
487     nodes = newnodes;
488   }
489 }
490
491 /* Constant folding.
492  *
493  * Try to constant fold some nodes, and pipe result directly to
494  * the input socket of connected nodes.
495  */
496 void ShaderGraph::constant_fold(Scene *scene)
497 {
498   ShaderNodeSet done, scheduled;
499   queue<ShaderNode *> traverse_queue;
500
501   bool has_displacement = (output()->input("Displacement")->link != NULL);
502
503   /* Schedule nodes which doesn't have any dependencies. */
504   foreach (ShaderNode *node, nodes) {
505     if (!check_node_inputs_has_links(node)) {
506       traverse_queue.push(node);
507       scheduled.insert(node);
508     }
509   }
510
511   while (!traverse_queue.empty()) {
512     ShaderNode *node = traverse_queue.front();
513     traverse_queue.pop();
514     done.insert(node);
515     foreach (ShaderOutput *output, node->outputs) {
516       if (output->links.size() == 0) {
517         continue;
518       }
519       /* Schedule node which was depending on the value,
520        * when possible. Do it before disconnect.
521        */
522       foreach (ShaderInput *input, output->links) {
523         if (scheduled.find(input->parent) != scheduled.end()) {
524           /* Node might not be optimized yet but scheduled already
525            * by other dependencies. No need to re-schedule it.
526            */
527           continue;
528         }
529         /* Schedule node if its inputs are fully done. */
530         if (check_node_inputs_traversed(input->parent, done)) {
531           traverse_queue.push(input->parent);
532           scheduled.insert(input->parent);
533         }
534       }
535       /* Optimize current node. */
536       ConstantFolder folder(this, node, output, scene);
537       node->constant_fold(folder);
538     }
539   }
540
541   /* Folding might have removed all nodes connected to the displacement output
542    * even tho there is displacement to be applied, so add in a value node if
543    * that happens to ensure there is still a valid graph for displacement.
544    */
545   if (has_displacement && !output()->input("Displacement")->link) {
546     ColorNode *value = (ColorNode *)add(new ColorNode());
547     value->value = output()->displacement;
548
549     connect(value->output("Color"), output()->input("Displacement"));
550   }
551 }
552
553 /* Simplification. */
554 void ShaderGraph::simplify_settings(Scene *scene)
555 {
556   foreach (ShaderNode *node, nodes) {
557     node->simplify_settings(scene);
558   }
559 }
560
561 /* Deduplicate nodes with same settings. */
562 void ShaderGraph::deduplicate_nodes()
563 {
564   /* NOTES:
565    * - Deduplication happens for nodes which has same exact settings and same
566    *   exact input links configuration (either connected to same output or has
567    *   the same exact default value).
568    * - Deduplication happens in the bottom-top manner, so we know for fact that
569    *   all traversed nodes are either can not be deduplicated at all or were
570    *   already deduplicated.
571    */
572
573   ShaderNodeSet scheduled, done;
574   map<ustring, ShaderNodeSet> candidates;
575   queue<ShaderNode *> traverse_queue;
576   int num_deduplicated = 0;
577
578   /* Schedule nodes which doesn't have any dependencies. */
579   foreach (ShaderNode *node, nodes) {
580     if (!check_node_inputs_has_links(node)) {
581       traverse_queue.push(node);
582       scheduled.insert(node);
583     }
584   }
585
586   while (!traverse_queue.empty()) {
587     ShaderNode *node = traverse_queue.front();
588     traverse_queue.pop();
589     done.insert(node);
590     /* Schedule the nodes which were depending on the current node. */
591     bool has_output_links = false;
592     foreach (ShaderOutput *output, node->outputs) {
593       foreach (ShaderInput *input, output->links) {
594         has_output_links = true;
595         if (scheduled.find(input->parent) != scheduled.end()) {
596           /* Node might not be optimized yet but scheduled already
597            * by other dependencies. No need to re-schedule it.
598            */
599           continue;
600         }
601         /* Schedule node if its inputs are fully done. */
602         if (check_node_inputs_traversed(input->parent, done)) {
603           traverse_queue.push(input->parent);
604           scheduled.insert(input->parent);
605         }
606       }
607     }
608     /* Only need to care about nodes that are actually used */
609     if (!has_output_links) {
610       continue;
611     }
612     /* Try to merge this node with another one. */
613     ShaderNode *merge_with = NULL;
614     foreach (ShaderNode *other_node, candidates[node->type->name]) {
615       if (node != other_node && node->equals(*other_node)) {
616         merge_with = other_node;
617         break;
618       }
619     }
620     /* If found an equivalent, merge; otherwise keep node for later merges */
621     if (merge_with != NULL) {
622       for (int i = 0; i < node->outputs.size(); ++i) {
623         relink(node, node->outputs[i], merge_with->outputs[i]);
624       }
625       num_deduplicated++;
626     }
627     else {
628       candidates[node->type->name].insert(node);
629     }
630   }
631
632   if (num_deduplicated > 0) {
633     VLOG(1) << "Deduplicated " << num_deduplicated << " nodes.";
634   }
635 }
636
637 /* Check whether volume output has meaningful nodes, otherwise
638  * disconnect the output.
639  */
640 void ShaderGraph::verify_volume_output()
641 {
642   /* Check whether we can optimize the whole volume graph out. */
643   ShaderInput *volume_in = output()->input("Volume");
644   if (volume_in->link == NULL) {
645     return;
646   }
647   bool has_valid_volume = false;
648   ShaderNodeSet scheduled;
649   queue<ShaderNode *> traverse_queue;
650   /* Schedule volume output. */
651   traverse_queue.push(volume_in->link->parent);
652   scheduled.insert(volume_in->link->parent);
653   /* Traverse down the tree. */
654   while (!traverse_queue.empty()) {
655     ShaderNode *node = traverse_queue.front();
656     traverse_queue.pop();
657     /* Node is fully valid for volume, can't optimize anything out. */
658     if (node->has_volume_support()) {
659       has_valid_volume = true;
660       break;
661     }
662     foreach (ShaderInput *input, node->inputs) {
663       if (input->link == NULL) {
664         continue;
665       }
666       if (scheduled.find(input->link->parent) != scheduled.end()) {
667         continue;
668       }
669       traverse_queue.push(input->link->parent);
670       scheduled.insert(input->link->parent);
671     }
672   }
673   if (!has_valid_volume) {
674     VLOG(1) << "Disconnect meaningless volume output.";
675     disconnect(volume_in->link);
676   }
677 }
678
679 void ShaderGraph::break_cycles(ShaderNode *node, vector<bool> &visited, vector<bool> &on_stack)
680 {
681   visited[node->id] = true;
682   on_stack[node->id] = true;
683
684   foreach (ShaderInput *input, node->inputs) {
685     if (input->link) {
686       ShaderNode *depnode = input->link->parent;
687
688       if (on_stack[depnode->id]) {
689         /* break cycle */
690         disconnect(input);
691         fprintf(stderr, "Cycles shader graph: detected cycle in graph, connection removed.\n");
692       }
693       else if (!visited[depnode->id]) {
694         /* visit dependencies */
695         break_cycles(depnode, visited, on_stack);
696       }
697     }
698   }
699
700   on_stack[node->id] = false;
701 }
702
703 void ShaderGraph::compute_displacement_hash()
704 {
705   /* Compute hash of all nodes linked to displacement, to detect if we need
706    * to recompute displacement when shader nodes change. */
707   ShaderInput *displacement_in = output()->input("Displacement");
708
709   if (!displacement_in->link) {
710     displacement_hash = "";
711     return;
712   }
713
714   ShaderNodeSet nodes_displace;
715   find_dependencies(nodes_displace, displacement_in);
716
717   MD5Hash md5;
718   foreach (ShaderNode *node, nodes_displace) {
719     node->hash(md5);
720     foreach (ShaderInput *input, node->inputs) {
721       int link_id = (input->link) ? input->link->parent->id : 0;
722       md5.append((uint8_t *)&link_id, sizeof(link_id));
723     }
724
725     if (node->special_type == SHADER_SPECIAL_TYPE_OSL) {
726       /* Hash takes into account socket values, to detect changes
727        * in the code of the node we need an exception. */
728       OSLNode *oslnode = static_cast<OSLNode *>(node);
729       md5.append(oslnode->bytecode_hash);
730     }
731   }
732
733   displacement_hash = md5.get_hex();
734 }
735
736 void ShaderGraph::clean(Scene *scene)
737 {
738   /* Graph simplification */
739
740   /* NOTE: Remove proxy nodes was already done. */
741   constant_fold(scene);
742   simplify_settings(scene);
743   deduplicate_nodes();
744   verify_volume_output();
745
746   /* we do two things here: find cycles and break them, and remove unused
747    * nodes that don't feed into the output. how cycles are broken is
748    * undefined, they are invalid input, the important thing is to not crash */
749
750   vector<bool> visited(num_node_ids, false);
751   vector<bool> on_stack(num_node_ids, false);
752
753   /* break cycles */
754   break_cycles(output(), visited, on_stack);
755
756   /* disconnect unused nodes */
757   foreach (ShaderNode *node, nodes) {
758     if (!visited[node->id]) {
759       foreach (ShaderInput *to, node->inputs) {
760         ShaderOutput *from = to->link;
761
762         if (from) {
763           to->link = NULL;
764           from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
765         }
766       }
767     }
768   }
769
770   /* remove unused nodes */
771   list<ShaderNode *> newnodes;
772
773   foreach (ShaderNode *node, nodes) {
774     if (visited[node->id])
775       newnodes.push_back(node);
776     else
777       delete node;
778   }
779
780   nodes = newnodes;
781 }
782
783 void ShaderGraph::default_inputs(bool do_osl)
784 {
785   /* nodes can specify default texture coordinates, for now we give
786    * everything the position by default, except for the sky texture */
787
788   ShaderNode *geom = NULL;
789   ShaderNode *texco = NULL;
790
791   foreach (ShaderNode *node, nodes) {
792     foreach (ShaderInput *input, node->inputs) {
793       if (!input->link && (!(input->flags() & SocketType::OSL_INTERNAL) || do_osl)) {
794         if (input->flags() & SocketType::LINK_TEXTURE_GENERATED) {
795           if (!texco)
796             texco = new TextureCoordinateNode();
797
798           connect(texco->output("Generated"), input);
799         }
800         if (input->flags() & SocketType::LINK_TEXTURE_NORMAL) {
801           if (!texco)
802             texco = new TextureCoordinateNode();
803
804           connect(texco->output("Normal"), input);
805         }
806         else if (input->flags() & SocketType::LINK_TEXTURE_UV) {
807           if (!texco)
808             texco = new TextureCoordinateNode();
809
810           connect(texco->output("UV"), input);
811         }
812         else if (input->flags() & SocketType::LINK_INCOMING) {
813           if (!geom)
814             geom = new GeometryNode();
815
816           connect(geom->output("Incoming"), input);
817         }
818         else if (input->flags() & SocketType::LINK_NORMAL) {
819           if (!geom)
820             geom = new GeometryNode();
821
822           connect(geom->output("Normal"), input);
823         }
824         else if (input->flags() & SocketType::LINK_POSITION) {
825           if (!geom)
826             geom = new GeometryNode();
827
828           connect(geom->output("Position"), input);
829         }
830         else if (input->flags() & SocketType::LINK_TANGENT) {
831           if (!geom)
832             geom = new GeometryNode();
833
834           connect(geom->output("Tangent"), input);
835         }
836       }
837     }
838   }
839
840   if (geom)
841     add(geom);
842   if (texco)
843     add(texco);
844 }
845
846 void ShaderGraph::refine_bump_nodes()
847 {
848   /* we transverse the node graph looking for bump nodes, when we find them,
849    * like in bump_from_displacement(), we copy the sub-graph defined from "bump"
850    * input to the inputs "center","dx" and "dy" What is in "bump" input is moved
851    * to "center" input. */
852
853   foreach (ShaderNode *node, nodes) {
854     if (node->special_type == SHADER_SPECIAL_TYPE_BUMP && node->input("Height")->link) {
855       ShaderInput *bump_input = node->input("Height");
856       ShaderNodeSet nodes_bump;
857
858       /* make 2 extra copies of the subgraph defined in Bump input */
859       ShaderNodeMap nodes_dx;
860       ShaderNodeMap nodes_dy;
861
862       /* find dependencies for the given input */
863       find_dependencies(nodes_bump, bump_input);
864
865       copy_nodes(nodes_bump, nodes_dx);
866       copy_nodes(nodes_bump, nodes_dy);
867
868       /* mark nodes to indicate they are use for bump computation, so
869          that any texture coordinates are shifted by dx/dy when sampling */
870       foreach (ShaderNode *node, nodes_bump)
871         node->bump = SHADER_BUMP_CENTER;
872       foreach (NodePair &pair, nodes_dx)
873         pair.second->bump = SHADER_BUMP_DX;
874       foreach (NodePair &pair, nodes_dy)
875         pair.second->bump = SHADER_BUMP_DY;
876
877       ShaderOutput *out = bump_input->link;
878       ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name());
879       ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name());
880
881       connect(out_dx, node->input("SampleX"));
882       connect(out_dy, node->input("SampleY"));
883
884       /* add generated nodes */
885       foreach (NodePair &pair, nodes_dx)
886         add(pair.second);
887       foreach (NodePair &pair, nodes_dy)
888         add(pair.second);
889
890       /* connect what is connected is bump to samplecenter input*/
891       connect(out, node->input("SampleCenter"));
892
893       /* bump input is just for connectivity purpose for the graph input,
894        * we re-connected this input to samplecenter, so lets disconnect it
895        * from bump input */
896       disconnect(bump_input);
897     }
898   }
899 }
900
901 void ShaderGraph::bump_from_displacement(bool use_object_space)
902 {
903   /* generate bump mapping automatically from displacement. bump mapping is
904    * done using a 3-tap filter, computing the displacement at the center,
905    * and two other positions shifted by ray differentials.
906    *
907    * since the input to displacement is a node graph, we need to ensure that
908    * all texture coordinates use are shift by the ray differentials. for this
909    * reason we make 3 copies of the node subgraph defining the displacement,
910    * with each different geometry and texture coordinate nodes that generate
911    * different shifted coordinates.
912    *
913    * these 3 displacement values are then fed into the bump node, which will
914    * output the perturbed normal. */
915
916   ShaderInput *displacement_in = output()->input("Displacement");
917
918   if (!displacement_in->link)
919     return;
920
921   /* find dependencies for the given input */
922   ShaderNodeSet nodes_displace;
923   find_dependencies(nodes_displace, displacement_in);
924
925   /* copy nodes for 3 bump samples */
926   ShaderNodeMap nodes_center;
927   ShaderNodeMap nodes_dx;
928   ShaderNodeMap nodes_dy;
929
930   copy_nodes(nodes_displace, nodes_center);
931   copy_nodes(nodes_displace, nodes_dx);
932   copy_nodes(nodes_displace, nodes_dy);
933
934   /* mark nodes to indicate they are use for bump computation, so
935    * that any texture coordinates are shifted by dx/dy when sampling */
936   foreach (NodePair &pair, nodes_center)
937     pair.second->bump = SHADER_BUMP_CENTER;
938   foreach (NodePair &pair, nodes_dx)
939     pair.second->bump = SHADER_BUMP_DX;
940   foreach (NodePair &pair, nodes_dy)
941     pair.second->bump = SHADER_BUMP_DY;
942
943   /* add set normal node and connect the bump normal ouput to the set normal
944    * output, so it can finally set the shader normal, note we are only doing
945    * this for bump from displacement, this will be the only bump allowed to
946    * overwrite the shader normal */
947   ShaderNode *set_normal = add(new SetNormalNode());
948
949   /* add bump node and connect copied graphs to it */
950   BumpNode *bump = (BumpNode *)add(new BumpNode());
951   bump->use_object_space = use_object_space;
952   bump->distance = 1.0f;
953
954   ShaderOutput *out = displacement_in->link;
955   ShaderOutput *out_center = nodes_center[out->parent]->output(out->name());
956   ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name());
957   ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name());
958
959   /* convert displacement vector to height */
960   VectorMathNode *dot_center = (VectorMathNode *)add(new VectorMathNode());
961   VectorMathNode *dot_dx = (VectorMathNode *)add(new VectorMathNode());
962   VectorMathNode *dot_dy = (VectorMathNode *)add(new VectorMathNode());
963
964   dot_center->type = NODE_VECTOR_MATH_DOT_PRODUCT;
965   dot_dx->type = NODE_VECTOR_MATH_DOT_PRODUCT;
966   dot_dy->type = NODE_VECTOR_MATH_DOT_PRODUCT;
967
968   GeometryNode *geom = (GeometryNode *)add(new GeometryNode());
969   connect(geom->output("Normal"), dot_center->input("Vector2"));
970   connect(geom->output("Normal"), dot_dx->input("Vector2"));
971   connect(geom->output("Normal"), dot_dy->input("Vector2"));
972
973   connect(out_center, dot_center->input("Vector1"));
974   connect(out_dx, dot_dx->input("Vector1"));
975   connect(out_dy, dot_dy->input("Vector1"));
976
977   connect(dot_center->output("Value"), bump->input("SampleCenter"));
978   connect(dot_dx->output("Value"), bump->input("SampleX"));
979   connect(dot_dy->output("Value"), bump->input("SampleY"));
980
981   /* connect the bump out to the set normal in: */
982   connect(bump->output("Normal"), set_normal->input("Direction"));
983
984   /* connect to output node */
985   connect(set_normal->output("Normal"), output()->input("Normal"));
986
987   /* finally, add the copied nodes to the graph. we can't do this earlier
988    * because we would create dependency cycles in the above loop */
989   foreach (NodePair &pair, nodes_center)
990     add(pair.second);
991   foreach (NodePair &pair, nodes_dx)
992     add(pair.second);
993   foreach (NodePair &pair, nodes_dy)
994     add(pair.second);
995 }
996
997 void ShaderGraph::transform_multi_closure(ShaderNode *node, ShaderOutput *weight_out, bool volume)
998 {
999   /* for SVM in multi closure mode, this transforms the shader mix/add part of
1000    * the graph into nodes that feed weights into closure nodes. this is too
1001    * avoid building a closure tree and then flattening it, and instead write it
1002    * directly to an array */
1003
1004   if (node->special_type == SHADER_SPECIAL_TYPE_COMBINE_CLOSURE) {
1005     ShaderInput *fin = node->input("Fac");
1006     ShaderInput *cl1in = node->input("Closure1");
1007     ShaderInput *cl2in = node->input("Closure2");
1008     ShaderOutput *weight1_out, *weight2_out;
1009
1010     if (fin) {
1011       /* mix closure: add node to mix closure weights */
1012       MixClosureWeightNode *mix_node = new MixClosureWeightNode();
1013       add(mix_node);
1014       ShaderInput *fac_in = mix_node->input("Fac");
1015       ShaderInput *weight_in = mix_node->input("Weight");
1016
1017       if (fin->link)
1018         connect(fin->link, fac_in);
1019       else
1020         mix_node->fac = node->get_float(fin->socket_type);
1021
1022       if (weight_out)
1023         connect(weight_out, weight_in);
1024
1025       weight1_out = mix_node->output("Weight1");
1026       weight2_out = mix_node->output("Weight2");
1027     }
1028     else {
1029       /* add closure: just pass on any weights */
1030       weight1_out = weight_out;
1031       weight2_out = weight_out;
1032     }
1033
1034     if (cl1in->link)
1035       transform_multi_closure(cl1in->link->parent, weight1_out, volume);
1036     if (cl2in->link)
1037       transform_multi_closure(cl2in->link->parent, weight2_out, volume);
1038   }
1039   else {
1040     ShaderInput *weight_in = node->input((volume) ? "VolumeMixWeight" : "SurfaceMixWeight");
1041
1042     /* not a closure node? */
1043     if (!weight_in)
1044       return;
1045
1046     /* already has a weight connected to it? add weights */
1047     float weight_value = node->get_float(weight_in->socket_type);
1048     if (weight_in->link || weight_value != 0.0f) {
1049       MathNode *math_node = new MathNode();
1050       add(math_node);
1051
1052       if (weight_in->link)
1053         connect(weight_in->link, math_node->input("Value1"));
1054       else
1055         math_node->value1 = weight_value;
1056
1057       if (weight_out)
1058         connect(weight_out, math_node->input("Value2"));
1059       else
1060         math_node->value2 = 1.0f;
1061
1062       weight_out = math_node->output("Value");
1063       if (weight_in->link)
1064         disconnect(weight_in);
1065     }
1066
1067     /* connected to closure mix weight */
1068     if (weight_out)
1069       connect(weight_out, weight_in);
1070     else
1071       node->set(weight_in->socket_type, weight_value + 1.0f);
1072   }
1073 }
1074
1075 int ShaderGraph::get_num_closures()
1076 {
1077   int num_closures = 0;
1078   foreach (ShaderNode *node, nodes) {
1079     ClosureType closure_type = node->get_closure_type();
1080     if (closure_type == CLOSURE_NONE_ID) {
1081       continue;
1082     }
1083     else if (CLOSURE_IS_BSSRDF(closure_type)) {
1084       num_closures += 3;
1085     }
1086     else if (CLOSURE_IS_GLASS(closure_type)) {
1087       num_closures += 2;
1088     }
1089     else if (CLOSURE_IS_BSDF_MULTISCATTER(closure_type)) {
1090       num_closures += 2;
1091     }
1092     else if (CLOSURE_IS_PRINCIPLED(closure_type)) {
1093       num_closures += 8;
1094     }
1095     else if (CLOSURE_IS_VOLUME(closure_type)) {
1096       num_closures += VOLUME_STACK_SIZE;
1097     }
1098     else if (closure_type == CLOSURE_BSDF_HAIR_PRINCIPLED_ID) {
1099       num_closures += 4;
1100     }
1101     else {
1102       ++num_closures;
1103     }
1104   }
1105   return num_closures;
1106 }
1107
1108 void ShaderGraph::dump_graph(const char *filename)
1109 {
1110   FILE *fd = fopen(filename, "w");
1111
1112   if (fd == NULL) {
1113     printf("Error opening file for dumping the graph: %s\n", filename);
1114     return;
1115   }
1116
1117   fprintf(fd, "digraph shader_graph {\n");
1118   fprintf(fd, "ranksep=1.5\n");
1119   fprintf(fd, "rankdir=LR\n");
1120   fprintf(fd, "splines=false\n");
1121
1122   foreach (ShaderNode *node, nodes) {
1123     fprintf(fd, "// NODE: %p\n", node);
1124     fprintf(fd, "\"%p\" [shape=record,label=\"{", node);
1125     if (node->inputs.size()) {
1126       fprintf(fd, "{");
1127       foreach (ShaderInput *socket, node->inputs) {
1128         if (socket != node->inputs[0]) {
1129           fprintf(fd, "|");
1130         }
1131         fprintf(fd, "<IN_%p>%s", socket, socket->name().c_str());
1132       }
1133       fprintf(fd, "}|");
1134     }
1135     fprintf(fd, "%s", node->name.c_str());
1136     if (node->bump == SHADER_BUMP_CENTER) {
1137       fprintf(fd, " (bump:center)");
1138     }
1139     else if (node->bump == SHADER_BUMP_DX) {
1140       fprintf(fd, " (bump:dx)");
1141     }
1142     else if (node->bump == SHADER_BUMP_DY) {
1143       fprintf(fd, " (bump:dy)");
1144     }
1145     if (node->outputs.size()) {
1146       fprintf(fd, "|{");
1147       foreach (ShaderOutput *socket, node->outputs) {
1148         if (socket != node->outputs[0]) {
1149           fprintf(fd, "|");
1150         }
1151         fprintf(fd, "<OUT_%p>%s", socket, socket->name().c_str());
1152       }
1153       fprintf(fd, "}");
1154     }
1155     fprintf(fd, "}\"]");
1156   }
1157
1158   foreach (ShaderNode *node, nodes) {
1159     foreach (ShaderOutput *output, node->outputs) {
1160       foreach (ShaderInput *input, output->links) {
1161         fprintf(fd,
1162                 "// CONNECTION: OUT_%p->IN_%p (%s:%s)\n",
1163                 output,
1164                 input,
1165                 output->name().c_str(),
1166                 input->name().c_str());
1167         fprintf(fd,
1168                 "\"%p\":\"OUT_%p\":e -> \"%p\":\"IN_%p\":w [label=\"\"]\n",
1169                 output->parent,
1170                 output,
1171                 input->parent,
1172                 input);
1173       }
1174     }
1175   }
1176
1177   fprintf(fd, "}\n");
1178   fclose(fd);
1179 }
1180
1181 CCL_NAMESPACE_END