Cycles: add unit tests for supported constant folding rules.
[blender.git] / intern / cycles / render / constant_fold.cpp
1 /*
2  * Copyright 2011-2013 Blender Foundation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "constant_fold.h"
18 #include "graph.h"
19
20 #include "util_foreach.h"
21 #include "util_logging.h"
22
23 CCL_NAMESPACE_BEGIN
24
25 ConstantFolder::ConstantFolder(ShaderGraph *graph, ShaderNode *node, ShaderOutput *output)
26 : graph(graph), node(node), output(output)
27 {
28 }
29
30 bool ConstantFolder::all_inputs_constant() const
31 {
32         foreach(ShaderInput *input, node->inputs) {
33                 if(input->link) {
34                         return false;
35                 }
36         }
37
38         return true;
39 }
40
41 void ConstantFolder::make_constant(float value) const
42 {
43         VLOG(1) << "Folding " << node->name << "::" << output->name() << " to constant (" << value << ").";
44
45         foreach(ShaderInput *sock, output->links) {
46                 sock->set(value);
47         }
48
49         graph->disconnect(output);
50 }
51
52 void ConstantFolder::make_constant(float3 value) const
53 {
54         VLOG(1) << "Folding " << node->name << "::" << output->name() << " to constant " << value << ".";
55
56         foreach(ShaderInput *sock, output->links) {
57                 sock->set(value);
58         }
59
60         graph->disconnect(output);
61 }
62
63 void ConstantFolder::make_constant_clamp(float value, bool clamp) const
64 {
65         make_constant(clamp ? saturate(value) : value);
66 }
67
68 void ConstantFolder::make_constant_clamp(float3 value, bool clamp) const
69 {
70         if(clamp) {
71                 value.x = saturate(value.x);
72                 value.y = saturate(value.y);
73                 value.z = saturate(value.z);
74         }
75
76         make_constant(value);
77 }
78
79 void ConstantFolder::make_zero() const
80 {
81         if(output->type() == SocketType::FLOAT) {
82                 make_constant(0.0f);
83         }
84         else if(SocketType::is_float3(output->type())) {
85                 make_constant(make_float3(0.0f, 0.0f, 0.0f));
86         }
87         else {
88                 assert(0);
89         }
90 }
91
92 void ConstantFolder::bypass(ShaderOutput *new_output) const
93 {
94         assert(new_output);
95
96         VLOG(1) << "Folding " << node->name << "::" << output->name() << " to socket " << new_output->parent->name << "::" << new_output->name() << ".";
97
98         /* Remove all outgoing links from socket and connect them to new_output instead.
99          * The graph->relink method affects node inputs, so it's not safe to use in constant
100          * folding if the node has multiple outputs and will thus be folded multiple times. */
101         vector<ShaderInput*> outputs = output->links;
102
103         graph->disconnect(output);
104
105         foreach(ShaderInput *sock, outputs) {
106                 graph->connect(new_output, sock);
107         }
108 }
109
110 void ConstantFolder::discard() const
111 {
112         assert(output->type() == SocketType::CLOSURE);
113
114         VLOG(1) << "Discarding closure " << node->name << ".";
115
116         graph->disconnect(output);
117 }
118
119 void ConstantFolder::bypass_or_discard(ShaderInput *input) const
120 {
121         assert(input->type() == SocketType::CLOSURE);
122
123         if(input->link) {
124                 bypass(input->link);
125         }
126         else {
127                 discard();
128         }
129 }
130
131 bool ConstantFolder::try_bypass_or_make_constant(ShaderInput *input, bool clamp) const
132 {
133         if(input->type() != output->type()) {
134                 return false;
135         }
136         else if(!input->link) {
137                 if(input->type() == SocketType::FLOAT) {
138                         make_constant_clamp(node->get_float(input->socket_type), clamp);
139                         return true;
140                 }
141                 else if(SocketType::is_float3(input->type())) {
142                         make_constant_clamp(node->get_float3(input->socket_type), clamp);
143                         return true;
144                 }
145         }
146         else if(!clamp) {
147                 bypass(input->link);
148                 return true;
149         }
150
151         return false;
152 }
153
154 bool ConstantFolder::is_zero(ShaderInput *input) const
155 {
156         if(!input->link) {
157                 if(input->type() == SocketType::FLOAT) {
158                         return node->get_float(input->socket_type) == 0.0f;
159                 }
160                 else if(SocketType::is_float3(input->type())) {
161                         return node->get_float3(input->socket_type) ==
162                                make_float3(0.0f, 0.0f, 0.0f);
163                 }
164         }
165
166         return false;
167 }
168
169 bool ConstantFolder::is_one(ShaderInput *input) const
170 {
171         if(!input->link) {
172                 if(input->type() == SocketType::FLOAT) {
173                         return node->get_float(input->socket_type) == 1.0f;
174                 }
175                 else if(SocketType::is_float3(input->type())) {
176                         return node->get_float3(input->socket_type) ==
177                                make_float3(1.0f, 1.0f, 1.0f);
178                 }
179         }
180
181         return false;
182 }
183
184 /* Specific nodes */
185
186 void ConstantFolder::fold_mix(NodeMix type, bool clamp) const
187 {
188     ShaderInput *fac_in = node->input("Fac");
189     ShaderInput *color1_in = node->input("Color1");
190     ShaderInput *color2_in = node->input("Color2");
191
192         float fac = saturate(node->get_float(fac_in->socket_type));
193         bool fac_is_zero = !fac_in->link && fac == 0.0f;
194         bool fac_is_one = !fac_in->link && fac == 1.0f;
195
196         /* remove no-op node when factor is 0.0 */
197         if(fac_is_zero) {
198                 /* note that some of the modes will clamp out of bounds values even without use_clamp */
199                 if(!(type == NODE_MIX_LIGHT || type == NODE_MIX_DODGE || type == NODE_MIX_BURN)) {
200                         if(try_bypass_or_make_constant(color1_in, clamp)) {
201                                 return;
202                         }
203                 }
204         }
205
206         switch(type) {
207                 case NODE_MIX_BLEND:
208                         /* remove useless mix colors nodes */
209                         if(color1_in->link && color2_in->link) {
210                                 if(color1_in->link == color2_in->link) {
211                                         try_bypass_or_make_constant(color1_in, clamp);
212                                         break;
213                                 }
214                         }
215                         else if(!color1_in->link && !color2_in->link) {
216                                 float3 color1 = node->get_float3(color1_in->socket_type);
217                                 float3 color2 = node->get_float3(color2_in->socket_type);
218                                 if(color1 == color2) {
219                                         try_bypass_or_make_constant(color1_in, clamp);
220                                         break;
221                                 }
222                         }
223                         /* remove no-op mix color node when factor is 1.0 */
224                         if(fac_is_one) {
225                                 try_bypass_or_make_constant(color2_in, clamp);
226                                 break;
227                         }
228                         break;
229                 case NODE_MIX_ADD:
230                         /* 0 + X (fac 1) == X */
231                         if(is_zero(color1_in) && fac_is_one) {
232                                 try_bypass_or_make_constant(color2_in, clamp);
233                         }
234                         /* X + 0 (fac ?) == X */
235                         else if(is_zero(color2_in)) {
236                                 try_bypass_or_make_constant(color1_in, clamp);
237                         }
238                         break;
239                 case NODE_MIX_SUB:
240                         /* X - 0 (fac ?) == X */
241                         if(is_zero(color2_in)) {
242                                 try_bypass_or_make_constant(color1_in, clamp);
243                         }
244                         /* X - X (fac 1) == 0 */
245                         else if(color1_in->link && color1_in->link == color2_in->link && fac_is_one) {
246                                 make_zero();
247                         }
248                         break;
249                 case NODE_MIX_MUL:
250                         /* X * 1 (fac ?) == X, 1 * X (fac 1) == X */
251                         if(is_one(color1_in) && fac_is_one) {
252                                 try_bypass_or_make_constant(color2_in, clamp);
253                         }
254                         else if(is_one(color2_in)) {
255                                 try_bypass_or_make_constant(color1_in, clamp);
256                         }
257                         /* 0 * ? (fac ?) == 0, ? * 0 (fac 1) == 0 */
258                         else if(is_zero(color1_in)) {
259                                 make_zero();
260                         }
261                         else if(is_zero(color2_in) && fac_is_one) {
262                                 make_zero();
263                         }
264                         break;
265                 case NODE_MIX_DIV:
266                         /* X / 1 (fac ?) == X */
267                         if(is_one(color2_in)) {
268                                 try_bypass_or_make_constant(color1_in, clamp);
269                         }
270                         /* 0 / ? (fac ?) == 0 */
271                         else if(is_zero(color1_in)) {
272                                 make_zero();
273                         }
274                         break;
275                 default:
276                         break;
277         }
278 }
279
280 void ConstantFolder::fold_math(NodeMath type, bool clamp) const
281 {
282         ShaderInput *value1_in = node->input("Value1");
283         ShaderInput *value2_in = node->input("Value2");
284
285         switch(type) {
286                 case NODE_MATH_ADD:
287                         /* X + 0 == 0 + X == X */
288                         if(is_zero(value1_in)) {
289                                 try_bypass_or_make_constant(value2_in, clamp);
290                         }
291                         else if(is_zero(value2_in)) {
292                                 try_bypass_or_make_constant(value1_in, clamp);
293                         }
294                         break;
295                 case NODE_MATH_SUBTRACT:
296                         /* X - 0 == X */
297                         if(is_zero(value2_in)) {
298                                 try_bypass_or_make_constant(value1_in, clamp);
299                         }
300                         break;
301                 case NODE_MATH_MULTIPLY:
302                         /* X * 1 == 1 * X == X */
303                         if(is_one(value1_in)) {
304                                 try_bypass_or_make_constant(value2_in, clamp);
305                         }
306                         else if(is_one(value2_in)) {
307                                 try_bypass_or_make_constant(value1_in, clamp);
308                         }
309                         /* X * 0 == 0 * X == 0 */
310                         else if(is_zero(value1_in) || is_zero(value2_in)) {
311                                 make_zero();
312                         }
313                         break;
314                 case NODE_MATH_DIVIDE:
315                         /* X / 1 == X */
316                         if(is_one(value2_in)) {
317                                 try_bypass_or_make_constant(value1_in, clamp);
318                         }
319                         /* 0 / X == 0 */
320                         else if(is_zero(value1_in)) {
321                                 make_zero();
322                         }
323                         break;
324                 default:
325                         break;
326         }
327 }
328
329 void ConstantFolder::fold_vector_math(NodeVectorMath type) const
330 {
331         ShaderInput *vector1_in = node->input("Vector1");
332         ShaderInput *vector2_in = node->input("Vector2");
333
334         switch(type) {
335                 case NODE_VECTOR_MATH_ADD:
336                         /* X + 0 == 0 + X == X */
337                         if(is_zero(vector1_in)) {
338                                 try_bypass_or_make_constant(vector2_in);
339                         }
340                         else if(is_zero(vector2_in)) {
341                                 try_bypass_or_make_constant(vector1_in);
342                         }
343                         break;
344                 case NODE_VECTOR_MATH_SUBTRACT:
345                         /* X - 0 == X */
346                         if(is_zero(vector2_in)) {
347                                 try_bypass_or_make_constant(vector1_in);
348                         }
349                         break;
350                 case NODE_VECTOR_MATH_DOT_PRODUCT:
351                 case NODE_VECTOR_MATH_CROSS_PRODUCT:
352                         /* X * 0 == 0 * X == 0 */
353                         if(is_zero(vector1_in) || is_zero(vector2_in)) {
354                                 make_zero();
355                         }
356                         break;
357                 default:
358                         break;
359         }
360 }
361
362 CCL_NAMESPACE_END