Merge branch 'master' into blender2.8
[blender.git] / intern / cycles / render / constant_fold.cpp
1 /*
2  * Copyright 2011-2013 Blender Foundation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "render/constant_fold.h"
18 #include "render/graph.h"
19
20 #include "util/util_foreach.h"
21 #include "util/util_logging.h"
22
23 CCL_NAMESPACE_BEGIN
24
25 ConstantFolder::ConstantFolder(ShaderGraph *graph, ShaderNode *node, ShaderOutput *output, Scene *scene)
26 : graph(graph), node(node), output(output), scene(scene)
27 {
28 }
29
30 bool ConstantFolder::all_inputs_constant() const
31 {
32         foreach(ShaderInput *input, node->inputs) {
33                 if(input->link) {
34                         return false;
35                 }
36         }
37
38         return true;
39 }
40
41 void ConstantFolder::make_constant(float value) const
42 {
43         VLOG(1) << "Folding " << node->name << "::" << output->name() << " to constant (" << value << ").";
44
45         foreach(ShaderInput *sock, output->links) {
46                 sock->set(value);
47         }
48
49         graph->disconnect(output);
50 }
51
52 void ConstantFolder::make_constant(float3 value) const
53 {
54         VLOG(1) << "Folding " << node->name << "::" << output->name() << " to constant " << value << ".";
55
56         foreach(ShaderInput *sock, output->links) {
57                 sock->set(value);
58         }
59
60         graph->disconnect(output);
61 }
62
63 void ConstantFolder::make_constant_clamp(float value, bool clamp) const
64 {
65         make_constant(clamp ? saturate(value) : value);
66 }
67
68 void ConstantFolder::make_constant_clamp(float3 value, bool clamp) const
69 {
70         if(clamp) {
71                 value.x = saturate(value.x);
72                 value.y = saturate(value.y);
73                 value.z = saturate(value.z);
74         }
75
76         make_constant(value);
77 }
78
79 void ConstantFolder::make_zero() const
80 {
81         if(output->type() == SocketType::FLOAT) {
82                 make_constant(0.0f);
83         }
84         else if(SocketType::is_float3(output->type())) {
85                 make_constant(make_float3(0.0f, 0.0f, 0.0f));
86         }
87         else {
88                 assert(0);
89         }
90 }
91
92 void ConstantFolder::make_one() const
93 {
94         if(output->type() == SocketType::FLOAT) {
95                 make_constant(1.0f);
96         }
97         else if(SocketType::is_float3(output->type())) {
98                 make_constant(make_float3(1.0f, 1.0f, 1.0f));
99         }
100         else {
101                 assert(0);
102         }
103 }
104
105 void ConstantFolder::bypass(ShaderOutput *new_output) const
106 {
107         assert(new_output);
108
109         VLOG(1) << "Folding " << node->name << "::" << output->name() << " to socket " << new_output->parent->name << "::" << new_output->name() << ".";
110
111         /* Remove all outgoing links from socket and connect them to new_output instead.
112          * The graph->relink method affects node inputs, so it's not safe to use in constant
113          * folding if the node has multiple outputs and will thus be folded multiple times. */
114         vector<ShaderInput*> outputs = output->links;
115
116         graph->disconnect(output);
117
118         foreach(ShaderInput *sock, outputs) {
119                 graph->connect(new_output, sock);
120         }
121 }
122
123 void ConstantFolder::discard() const
124 {
125         assert(output->type() == SocketType::CLOSURE);
126
127         VLOG(1) << "Discarding closure " << node->name << ".";
128
129         graph->disconnect(output);
130 }
131
132 void ConstantFolder::bypass_or_discard(ShaderInput *input) const
133 {
134         assert(input->type() == SocketType::CLOSURE);
135
136         if(input->link) {
137                 bypass(input->link);
138         }
139         else {
140                 discard();
141         }
142 }
143
144 bool ConstantFolder::try_bypass_or_make_constant(ShaderInput *input, bool clamp) const
145 {
146         if(input->type() != output->type()) {
147                 return false;
148         }
149         else if(!input->link) {
150                 if(input->type() == SocketType::FLOAT) {
151                         make_constant_clamp(node->get_float(input->socket_type), clamp);
152                         return true;
153                 }
154                 else if(SocketType::is_float3(input->type())) {
155                         make_constant_clamp(node->get_float3(input->socket_type), clamp);
156                         return true;
157                 }
158         }
159         else if(!clamp) {
160                 bypass(input->link);
161                 return true;
162         }
163         else {
164                 /* disconnect other inputs if we can't fully bypass due to clamp */
165                 foreach(ShaderInput *other, node->inputs) {
166                         if(other != input && other->link) {
167                                 graph->disconnect(other);
168                         }
169                 }
170         }
171
172         return false;
173 }
174
175 bool ConstantFolder::is_zero(ShaderInput *input) const
176 {
177         if(!input->link) {
178                 if(input->type() == SocketType::FLOAT) {
179                         return node->get_float(input->socket_type) == 0.0f;
180                 }
181                 else if(SocketType::is_float3(input->type())) {
182                         return node->get_float3(input->socket_type) ==
183                                make_float3(0.0f, 0.0f, 0.0f);
184                 }
185         }
186
187         return false;
188 }
189
190 bool ConstantFolder::is_one(ShaderInput *input) const
191 {
192         if(!input->link) {
193                 if(input->type() == SocketType::FLOAT) {
194                         return node->get_float(input->socket_type) == 1.0f;
195                 }
196                 else if(SocketType::is_float3(input->type())) {
197                         return node->get_float3(input->socket_type) ==
198                                make_float3(1.0f, 1.0f, 1.0f);
199                 }
200         }
201
202         return false;
203 }
204
205 /* Specific nodes */
206
207 void ConstantFolder::fold_mix(NodeMix type, bool clamp) const
208 {
209     ShaderInput *fac_in = node->input("Fac");
210     ShaderInput *color1_in = node->input("Color1");
211     ShaderInput *color2_in = node->input("Color2");
212
213         float fac = saturate(node->get_float(fac_in->socket_type));
214         bool fac_is_zero = !fac_in->link && fac == 0.0f;
215         bool fac_is_one = !fac_in->link && fac == 1.0f;
216
217         /* remove no-op node when factor is 0.0 */
218         if(fac_is_zero) {
219                 /* note that some of the modes will clamp out of bounds values even without use_clamp */
220                 if(!(type == NODE_MIX_LIGHT || type == NODE_MIX_DODGE || type == NODE_MIX_BURN)) {
221                         if(try_bypass_or_make_constant(color1_in, clamp)) {
222                                 return;
223                         }
224                 }
225         }
226
227         switch(type) {
228                 case NODE_MIX_BLEND:
229                         /* remove useless mix colors nodes */
230                         if(color1_in->link && color2_in->link) {
231                                 if(color1_in->link == color2_in->link) {
232                                         try_bypass_or_make_constant(color1_in, clamp);
233                                         break;
234                                 }
235                         }
236                         else if(!color1_in->link && !color2_in->link) {
237                                 float3 color1 = node->get_float3(color1_in->socket_type);
238                                 float3 color2 = node->get_float3(color2_in->socket_type);
239                                 if(color1 == color2) {
240                                         try_bypass_or_make_constant(color1_in, clamp);
241                                         break;
242                                 }
243                         }
244                         /* remove no-op mix color node when factor is 1.0 */
245                         if(fac_is_one) {
246                                 try_bypass_or_make_constant(color2_in, clamp);
247                                 break;
248                         }
249                         break;
250                 case NODE_MIX_ADD:
251                         /* 0 + X (fac 1) == X */
252                         if(is_zero(color1_in) && fac_is_one) {
253                                 try_bypass_or_make_constant(color2_in, clamp);
254                         }
255                         /* X + 0 (fac ?) == X */
256                         else if(is_zero(color2_in)) {
257                                 try_bypass_or_make_constant(color1_in, clamp);
258                         }
259                         break;
260                 case NODE_MIX_SUB:
261                         /* X - 0 (fac ?) == X */
262                         if(is_zero(color2_in)) {
263                                 try_bypass_or_make_constant(color1_in, clamp);
264                         }
265                         /* X - X (fac 1) == 0 */
266                         else if(color1_in->link && color1_in->link == color2_in->link && fac_is_one) {
267                                 make_zero();
268                         }
269                         break;
270                 case NODE_MIX_MUL:
271                         /* X * 1 (fac ?) == X, 1 * X (fac 1) == X */
272                         if(is_one(color1_in) && fac_is_one) {
273                                 try_bypass_or_make_constant(color2_in, clamp);
274                         }
275                         else if(is_one(color2_in)) {
276                                 try_bypass_or_make_constant(color1_in, clamp);
277                         }
278                         /* 0 * ? (fac ?) == 0, ? * 0 (fac 1) == 0 */
279                         else if(is_zero(color1_in)) {
280                                 make_zero();
281                         }
282                         else if(is_zero(color2_in) && fac_is_one) {
283                                 make_zero();
284                         }
285                         break;
286                 case NODE_MIX_DIV:
287                         /* X / 1 (fac ?) == X */
288                         if(is_one(color2_in)) {
289                                 try_bypass_or_make_constant(color1_in, clamp);
290                         }
291                         /* 0 / ? (fac ?) == 0 */
292                         else if(is_zero(color1_in)) {
293                                 make_zero();
294                         }
295                         break;
296                 default:
297                         break;
298         }
299 }
300
301 void ConstantFolder::fold_math(NodeMath type, bool clamp) const
302 {
303         ShaderInput *value1_in = node->input("Value1");
304         ShaderInput *value2_in = node->input("Value2");
305
306         switch(type) {
307                 case NODE_MATH_ADD:
308                         /* X + 0 == 0 + X == X */
309                         if(is_zero(value1_in)) {
310                                 try_bypass_or_make_constant(value2_in, clamp);
311                         }
312                         else if(is_zero(value2_in)) {
313                                 try_bypass_or_make_constant(value1_in, clamp);
314                         }
315                         break;
316                 case NODE_MATH_SUBTRACT:
317                         /* X - 0 == X */
318                         if(is_zero(value2_in)) {
319                                 try_bypass_or_make_constant(value1_in, clamp);
320                         }
321                         break;
322                 case NODE_MATH_MULTIPLY:
323                         /* X * 1 == 1 * X == X */
324                         if(is_one(value1_in)) {
325                                 try_bypass_or_make_constant(value2_in, clamp);
326                         }
327                         else if(is_one(value2_in)) {
328                                 try_bypass_or_make_constant(value1_in, clamp);
329                         }
330                         /* X * 0 == 0 * X == 0 */
331                         else if(is_zero(value1_in) || is_zero(value2_in)) {
332                                 make_zero();
333                         }
334                         break;
335                 case NODE_MATH_DIVIDE:
336                         /* X / 1 == X */
337                         if(is_one(value2_in)) {
338                                 try_bypass_or_make_constant(value1_in, clamp);
339                         }
340                         /* 0 / X == 0 */
341                         else if(is_zero(value1_in)) {
342                                 make_zero();
343                         }
344                         break;
345                 case NODE_MATH_POWER:
346                         /* 1 ^ X == X ^ 0 == 1 */
347                         if(is_one(value1_in) || is_zero(value2_in)) {
348                                 make_one();
349                         }
350                         /* X ^ 1 == X */
351                         else if(is_one(value2_in)) {
352                                 try_bypass_or_make_constant(value1_in, clamp);
353                         }
354                 default:
355                         break;
356         }
357 }
358
359 void ConstantFolder::fold_vector_math(NodeVectorMath type) const
360 {
361         ShaderInput *vector1_in = node->input("Vector1");
362         ShaderInput *vector2_in = node->input("Vector2");
363
364         switch(type) {
365                 case NODE_VECTOR_MATH_ADD:
366                         /* X + 0 == 0 + X == X */
367                         if(is_zero(vector1_in)) {
368                                 try_bypass_or_make_constant(vector2_in);
369                         }
370                         else if(is_zero(vector2_in)) {
371                                 try_bypass_or_make_constant(vector1_in);
372                         }
373                         break;
374                 case NODE_VECTOR_MATH_SUBTRACT:
375                         /* X - 0 == X */
376                         if(is_zero(vector2_in)) {
377                                 try_bypass_or_make_constant(vector1_in);
378                         }
379                         break;
380                 case NODE_VECTOR_MATH_DOT_PRODUCT:
381                 case NODE_VECTOR_MATH_CROSS_PRODUCT:
382                         /* X * 0 == 0 * X == 0 */
383                         if(is_zero(vector1_in) || is_zero(vector2_in)) {
384                                 make_zero();
385                         }
386                         break;
387                 default:
388                         break;
389         }
390 }
391
392 CCL_NAMESPACE_END