6ed0812a23973a563e0cd9e105d1301b419ffa59
[blender-staging.git] / intern / cycles / render / graph.cpp
1 /*
2  * Copyright 2011, Blender Foundation.
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU General Public License
6  * as published by the Free Software Foundation; either version 2
7  * of the License, or (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software Foundation,
16  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17  */
18
19 #include "attribute.h"
20 #include "graph.h"
21 #include "nodes.h"
22
23 #include "util_algorithm.h"
24 #include "util_debug.h"
25 #include "util_foreach.h"
26
27 CCL_NAMESPACE_BEGIN
28
29 /* Input and Output */
30
31 ShaderInput::ShaderInput(ShaderNode *parent_, const char *name_, ShaderSocketType type_)
32 {
33         parent = parent_;
34         name = name_;
35         type = type_;
36         link = NULL;
37         value = make_float3(0, 0, 0);
38         stack_offset = SVM_STACK_INVALID;
39         default_value = NONE;
40         osl_only = false;
41 }
42
43 ShaderOutput::ShaderOutput(ShaderNode *parent_, const char *name_, ShaderSocketType type_)
44 {
45         parent = parent_;
46         name = name_;
47         type = type_;
48         stack_offset = SVM_STACK_INVALID;
49 }
50
51 /* Node */
52
53 ShaderNode::ShaderNode(const char *name_)
54 {
55         name = name_;
56         id = -1;
57         bump = SHADER_BUMP_NONE;
58         special_type = SHADER_SPECIAL_TYPE_NONE;
59 }
60
61 ShaderNode::~ShaderNode()
62 {
63         foreach(ShaderInput *socket, inputs)
64                 delete socket;
65
66         foreach(ShaderOutput *socket, outputs)
67                 delete socket;
68 }
69
70 ShaderInput *ShaderNode::input(const char *name)
71 {
72         foreach(ShaderInput *socket, inputs)
73                 if(strcmp(socket->name, name) == 0)
74                         return socket;
75
76         return NULL;
77 }
78
79 ShaderOutput *ShaderNode::output(const char *name)
80 {
81         foreach(ShaderOutput *socket, outputs)
82                 if(strcmp(socket->name, name) == 0)
83                         return socket;
84
85         return NULL;
86 }
87
88 ShaderInput *ShaderNode::add_input(const char *name, ShaderSocketType type, float value)
89 {
90         ShaderInput *input = new ShaderInput(this, name, type);
91         input->value.x = value;
92         inputs.push_back(input);
93         return input;
94 }
95
96 ShaderInput *ShaderNode::add_input(const char *name, ShaderSocketType type, float3 value)
97 {
98         ShaderInput *input = new ShaderInput(this, name, type);
99         input->value = value;
100         inputs.push_back(input);
101         return input;
102 }
103
104 ShaderInput *ShaderNode::add_input(const char *name, ShaderSocketType type, ShaderInput::DefaultValue value, bool osl_only)
105 {
106         ShaderInput *input = add_input(name, type);
107         input->default_value = value;
108         input->osl_only = osl_only;
109         return input;
110 }
111
112 ShaderOutput *ShaderNode::add_output(const char *name, ShaderSocketType type)
113 {
114         ShaderOutput *output = new ShaderOutput(this, name, type);
115         outputs.push_back(output);
116         return output;
117 }
118
119 void ShaderNode::attributes(AttributeRequestSet *attributes)
120 {
121         foreach(ShaderInput *input, inputs) {
122                 if(!input->link) {
123                         if(input->default_value == ShaderInput::TEXTURE_GENERATED)
124                                 attributes->add(ATTR_STD_GENERATED);
125                         else if(input->default_value == ShaderInput::TEXTURE_UV)
126                                 attributes->add(ATTR_STD_UV);
127                 }
128         }
129 }
130
131 /* Graph */
132
133 ShaderGraph::ShaderGraph()
134 {
135         finalized = false;
136         add(new OutputNode());
137 }
138
139 ShaderGraph::~ShaderGraph()
140 {
141         foreach(ShaderNode *node, nodes)
142                 delete node;
143 }
144
145 ShaderNode *ShaderGraph::add(ShaderNode *node)
146 {
147         assert(!finalized);
148         node->id = nodes.size();
149         nodes.push_back(node);
150         return node;
151 }
152
153 ShaderNode *ShaderGraph::output()
154 {
155         return nodes.front();
156 }
157
158 ShaderGraph *ShaderGraph::copy()
159 {
160         ShaderGraph *newgraph = new ShaderGraph();
161
162         /* copy nodes */
163         set<ShaderNode*> nodes_all;
164         foreach(ShaderNode *node, nodes)
165                 nodes_all.insert(node);
166
167         map<ShaderNode*, ShaderNode*> nodes_copy;
168         copy_nodes(nodes_all, nodes_copy);
169
170         /* add nodes (in same order, so output is still first) */
171         newgraph->nodes.clear();
172         foreach(ShaderNode *node, nodes)
173                 newgraph->add(nodes_copy[node]);
174
175         return newgraph;
176 }
177
178 void ShaderGraph::connect(ShaderOutput *from, ShaderInput *to)
179 {
180         assert(!finalized);
181         assert(from && to);
182
183         if(to->link) {
184                 fprintf(stderr, "ShaderGraph connect: input already connected.\n");
185                 return;
186         }
187
188         if(from->type != to->type) {
189                 /* for closures we can't do automatic conversion */
190                 if(from->type == SHADER_SOCKET_CLOSURE || to->type == SHADER_SOCKET_CLOSURE) {
191                         fprintf(stderr, "ShaderGraph connect: can only connect closure to closure "
192                                 "(ShaderNode:%s, ShaderOutput:%s , type:%d -> to ShaderNode:%s, ShaderInput:%s, type:%d).\n",
193                                 from->parent->name.c_str(), from->name, (int)from->type,
194                                 to->parent->name.c_str(),   to->name,   (int)to->type);
195                         return;
196                 }
197
198                 /* add automatic conversion node in case of type mismatch */
199                 ShaderNode *convert = add(new ConvertNode(from->type, to->type));
200
201                 connect(from, convert->inputs[0]);
202                 connect(convert->outputs[0], to);
203         }
204         else {
205                 /* types match, just connect */
206                 to->link = from;
207                 from->links.push_back(to);
208         }
209 }
210
211 void ShaderGraph::disconnect(ShaderInput *to)
212 {
213         assert(!finalized);
214         assert(to->link);
215
216         ShaderOutput *from = to->link;
217
218         to->link = NULL;
219         from->links.erase(remove(from->links.begin(), from->links.end(), to), from->links.end());
220 }
221
222 void ShaderGraph::finalize(bool do_bump, bool do_osl)
223 {
224         /* before compiling, the shader graph may undergo a number of modifications.
225          * currently we set default geometry shader inputs, and create automatic bump
226          * from displacement. a graph can be finalized only once, and should not be
227          * modified afterwards. */
228
229         if(!finalized) {
230                 clean();
231                 default_inputs(do_osl);
232                 if(do_bump)
233                         bump_from_displacement();
234
235                 finalized = true;
236         }
237 }
238
239 void ShaderGraph::find_dependencies(set<ShaderNode*>& dependencies, ShaderInput *input)
240 {
241         /* find all nodes that this input dependes on directly and indirectly */
242         ShaderNode *node = (input->link)? input->link->parent: NULL;
243
244         if(node) {
245                 foreach(ShaderInput *in, node->inputs)
246                         find_dependencies(dependencies, in);
247
248                 dependencies.insert(node);
249         }
250 }
251
252 void ShaderGraph::copy_nodes(set<ShaderNode*>& nodes, map<ShaderNode*, ShaderNode*>& nnodemap)
253 {
254         /* copy a set of nodes, and the links between them. the assumption is
255          * made that all nodes that inputs are linked to are in the set too. */
256
257         /* copy nodes */
258         foreach(ShaderNode *node, nodes) {
259                 ShaderNode *nnode = node->clone();
260                 nnodemap[node] = nnode;
261
262                 nnode->inputs.clear();
263                 nnode->outputs.clear();
264
265                 foreach(ShaderInput *input, node->inputs) {
266                         ShaderInput *ninput = new ShaderInput(*input);
267                         nnode->inputs.push_back(ninput);
268
269                         ninput->parent = nnode;
270                         ninput->link = NULL;
271                 }
272
273                 foreach(ShaderOutput *output, node->outputs) {
274                         ShaderOutput *noutput = new ShaderOutput(*output);
275                         nnode->outputs.push_back(noutput);
276
277                         noutput->parent = nnode;
278                         noutput->links.clear();
279                 }
280         }
281
282         /* recreate links */
283         foreach(ShaderNode *node, nodes) {
284                 foreach(ShaderInput *input, node->inputs) {
285                         if(input->link) {
286                                 /* find new input and output */
287                                 ShaderNode *nfrom = nnodemap[input->link->parent];
288                                 ShaderNode *nto = nnodemap[input->parent];
289                                 ShaderOutput *noutput = nfrom->output(input->link->name);
290                                 ShaderInput *ninput = nto->input(input->name);
291
292                                 /* connect */
293                                 connect(noutput, ninput);
294                         }
295                 }
296         }
297 }
298
299 void ShaderGraph::remove_proxy_nodes(vector<bool>& removed)
300 {
301         foreach(ShaderNode *node, nodes) {
302                 if (node->special_type == SHADER_SPECIAL_TYPE_PROXY) {
303                         ProxyNode *proxy = static_cast<ProxyNode*>(node);
304                         ShaderInput *input = proxy->inputs[0];
305                         ShaderOutput *output = proxy->outputs[0];
306                         
307                         /* temp. copy of the output links list.
308                          * output->links is modified when we disconnect!
309                          */
310                         vector<ShaderInput*> links(output->links);
311                         ShaderOutput *from = input->link;
312                         
313                         /* bypass the proxy node */
314                         if (from) {
315                                 disconnect(input);
316                                 foreach(ShaderInput *to, links) {
317                                         disconnect(to);
318                                         connect(from, to);
319                                 }
320                         }
321                         else {
322                                 foreach(ShaderInput *to, links) {
323                                         disconnect(to);
324                                         
325                                         /* transfer the default input value to the target socket */
326                                         to->set(input->value);
327                                 }
328                         }
329                         
330                         removed[proxy->id] = true;
331                 }
332
333                 /* remove useless mix closures nodes */
334                 if(node->special_type == SHADER_SPECIAL_TYPE_MIX_CLOSURE) {
335                         MixClosureNode *mix = static_cast<MixClosureNode*>(node);
336                         if(mix->outputs[0]->links.size() && mix->inputs[1]->link == mix->inputs[2]->link) {
337                                 ShaderOutput *output = mix->inputs[1]->link;
338                                 vector<ShaderInput*> inputs = mix->outputs[0]->links;
339
340                                 foreach(ShaderInput *sock, mix->inputs)
341                                         if(sock->link)
342                                                 disconnect(sock);
343
344                                 foreach(ShaderInput *input, inputs) {
345                                         disconnect(input);
346                                         if (output)
347                                                 connect(output, input);
348                                 }
349                         }
350                 }
351         }
352 }
353
354 void ShaderGraph::break_cycles(ShaderNode *node, vector<bool>& visited, vector<bool>& on_stack)
355 {
356         visited[node->id] = true;
357         on_stack[node->id] = true;
358
359         foreach(ShaderInput *input, node->inputs) {
360                 if(input->link) {
361                         ShaderNode *depnode = input->link->parent;
362
363                         if(on_stack[depnode->id]) {
364                                 /* break cycle */
365                                 disconnect(input);
366                                 fprintf(stderr, "ShaderGraph: detected cycle in graph, connection removed.\n");
367                         }
368                         else if(!visited[depnode->id]) {
369                                 /* visit dependencies */
370                                 break_cycles(depnode, visited, on_stack);
371                         }
372                 }
373         }
374
375         on_stack[node->id] = false;
376 }
377
378 void ShaderGraph::clean()
379 {
380         /* we do two things here: find cycles and break them, and remove unused
381          * nodes that don't feed into the output. how cycles are broken is
382          * undefined, they are invalid input, the important thing is to not crash */
383
384         vector<bool> removed(nodes.size(), false);
385         vector<bool> visited(nodes.size(), false);
386         vector<bool> on_stack(nodes.size(), false);
387         
388         list<ShaderNode*> newnodes;
389         
390         /* remove proxy nodes */
391         remove_proxy_nodes(removed);
392         
393         foreach(ShaderNode *node, nodes) {
394                 if(!removed[node->id])
395                         newnodes.push_back(node);
396                 else
397                         delete node;
398         }
399         nodes = newnodes;
400         newnodes.clear();
401
402         /* break cycles */
403         break_cycles(output(), visited, on_stack);
404
405         /* remove unused nodes */
406         foreach(ShaderNode *node, nodes) {
407                 if(visited[node->id])
408                         newnodes.push_back(node);
409                 else
410                         delete node;
411         }
412         
413         nodes = newnodes;
414 }
415
416 void ShaderGraph::default_inputs(bool do_osl)
417 {
418         /* nodes can specify default texture coordinates, for now we give
419          * everything the position by default, except for the sky texture */
420
421         ShaderNode *geom = NULL;
422         ShaderNode *texco = NULL;
423
424         foreach(ShaderNode *node, nodes) {
425                 foreach(ShaderInput *input, node->inputs) {
426                         if(!input->link && !(input->osl_only && !do_osl)) {
427                                 if(input->default_value == ShaderInput::TEXTURE_GENERATED) {
428                                         if(!texco)
429                                                 texco = new TextureCoordinateNode();
430
431                                         connect(texco->output("Generated"), input);
432                                 }
433                                 else if(input->default_value == ShaderInput::TEXTURE_UV) {
434                                         if(!texco)
435                                                 texco = new TextureCoordinateNode();
436
437                                         connect(texco->output("UV"), input);
438                                 }
439                                 else if(input->default_value == ShaderInput::INCOMING) {
440                                         if(!geom)
441                                                 geom = new GeometryNode();
442
443                                         connect(geom->output("Incoming"), input);
444                                 }
445                                 else if(input->default_value == ShaderInput::NORMAL) {
446                                         if(!geom)
447                                                 geom = new GeometryNode();
448
449                                         connect(geom->output("Normal"), input);
450                                 }
451                                 else if(input->default_value == ShaderInput::POSITION) {
452                                         if(!geom)
453                                                 geom = new GeometryNode();
454
455                                         connect(geom->output("Position"), input);
456                                 }
457                         }
458                 }
459         }
460
461         if(geom)
462                 add(geom);
463         if(texco)
464                 add(texco);
465 }
466
467 void ShaderGraph::bump_from_displacement()
468 {
469         /* generate bump mapping automatically from displacement. bump mapping is
470          * done using a 3-tap filter, computing the displacement at the center,
471          * and two other positions shifted by ray differentials.
472          *
473          * since the input to displacement is a node graph, we need to ensure that
474          * all texture coordinates use are shift by the ray differentials. for this
475          * reason we make 3 copies of the node subgraph defining the displacement,
476          * with each different geometry and texture coordinate nodes that generate
477          * different shifted coordinates.
478          *
479          * these 3 displacement values are then fed into the bump node, which will
480          * modify the normal. */
481
482         ShaderInput *displacement_in = output()->input("Displacement");
483
484         if(!displacement_in->link)
485                 return;
486         
487         /* find dependencies for the given input */
488         set<ShaderNode*> nodes_displace;
489         find_dependencies(nodes_displace, displacement_in);
490
491         /* copy nodes for 3 bump samples */
492         map<ShaderNode*, ShaderNode*> nodes_center;
493         map<ShaderNode*, ShaderNode*> nodes_dx;
494         map<ShaderNode*, ShaderNode*> nodes_dy;
495
496         copy_nodes(nodes_displace, nodes_center);
497         copy_nodes(nodes_displace, nodes_dx);
498         copy_nodes(nodes_displace, nodes_dy);
499
500         /* mark nodes to indicate they are use for bump computation, so
501          * that any texture coordinates are shifted by dx/dy when sampling */
502         foreach(NodePair& pair, nodes_center)
503                 pair.second->bump = SHADER_BUMP_CENTER;
504         foreach(NodePair& pair, nodes_dx)
505                 pair.second->bump = SHADER_BUMP_DX;
506         foreach(NodePair& pair, nodes_dy)
507                 pair.second->bump = SHADER_BUMP_DY;
508
509         /* add bump node and connect copied graphs to it */
510         ShaderNode *bump = add(new BumpNode());
511
512         ShaderOutput *out = displacement_in->link;
513         ShaderOutput *out_center = nodes_center[out->parent]->output(out->name);
514         ShaderOutput *out_dx = nodes_dx[out->parent]->output(out->name);
515         ShaderOutput *out_dy = nodes_dy[out->parent]->output(out->name);
516
517         connect(out_center, bump->input("SampleCenter"));
518         connect(out_dx, bump->input("SampleX"));
519         connect(out_dy, bump->input("SampleY"));
520
521         /* connect bump output to normal input nodes that aren't set yet. actually
522          * this will only set the normal input to the geometry node that we created
523          * and connected to all other normal inputs already. */
524         foreach(ShaderNode *node, nodes)
525                 foreach(ShaderInput *input, node->inputs)
526                         if(!input->link && input->default_value == ShaderInput::NORMAL)
527                                 connect(bump->output("Normal"), input);
528         
529         /* finally, add the copied nodes to the graph. we can't do this earlier
530          * because we would create dependency cycles in the above loop */
531         foreach(NodePair& pair, nodes_center)
532                 add(pair.second);
533         foreach(NodePair& pair, nodes_dx)
534                 add(pair.second);
535         foreach(NodePair& pair, nodes_dy)
536                 add(pair.second);
537 }
538
539 CCL_NAMESPACE_END
540