Fix click-drag regression in fix for T86116
[blender.git] / source / blender / nodes / intern / node_tree_multi_function.cc
1 /*
2  * This program is free software; you can redistribute it and/or
3  * modify it under the terms of the GNU General Public License
4  * as published by the Free Software Foundation; either version 2
5  * of the License, or (at your option) any later version.
6  *
7  * This program is distributed in the hope that it will be useful,
8  * but WITHOUT ANY WARRANTY; without even the implied warranty of
9  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  * GNU General Public License for more details.
11  *
12  * You should have received a copy of the GNU General Public License
13  * along with this program; if not, write to the Free Software Foundation,
14  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
15  */
16
17 #include "NOD_node_tree_multi_function.hh"
18
19 #include "FN_multi_function_network_evaluation.hh"
20
21 #include "BLI_color.hh"
22 #include "BLI_float2.hh"
23 #include "BLI_float3.hh"
24
25 namespace blender::nodes {
26
27 const fn::MultiFunction &NodeMFNetworkBuilder::get_default_fn(StringRef name)
28 {
29   Vector<fn::MFDataType, 10> input_types;
30   Vector<fn::MFDataType, 10> output_types;
31
32   for (const DInputSocket *dsocket : dnode_.inputs()) {
33     if (dsocket->is_available()) {
34       std::optional<fn::MFDataType> data_type = socket_mf_type_get(*dsocket->bsocket()->typeinfo);
35       if (data_type.has_value()) {
36         input_types.append(*data_type);
37       }
38     }
39   }
40   for (const DOutputSocket *dsocket : dnode_.outputs()) {
41     if (dsocket->is_available()) {
42       std::optional<fn::MFDataType> data_type = socket_mf_type_get(*dsocket->bsocket()->typeinfo);
43       if (data_type.has_value()) {
44         output_types.append(*data_type);
45       }
46     }
47   }
48
49   const fn::MultiFunction &fn = this->construct_fn<fn::CustomMF_DefaultOutput>(
50       name, input_types, output_types);
51   return fn;
52 }
53
54 static void insert_dummy_node(CommonMFNetworkBuilderData &common, const DNode &dnode)
55 {
56   constexpr int stack_capacity = 10;
57
58   Vector<fn::MFDataType, stack_capacity> input_types;
59   Vector<StringRef, stack_capacity> input_names;
60   Vector<const DInputSocket *, stack_capacity> input_dsockets;
61
62   for (const DInputSocket *dsocket : dnode.inputs()) {
63     if (dsocket->is_available()) {
64       std::optional<fn::MFDataType> data_type = socket_mf_type_get(*dsocket->bsocket()->typeinfo);
65       if (data_type.has_value()) {
66         input_types.append(*data_type);
67         input_names.append(dsocket->name());
68         input_dsockets.append(dsocket);
69       }
70     }
71   }
72
73   Vector<fn::MFDataType, stack_capacity> output_types;
74   Vector<StringRef, stack_capacity> output_names;
75   Vector<const DOutputSocket *, stack_capacity> output_dsockets;
76
77   for (const DOutputSocket *dsocket : dnode.outputs()) {
78     if (dsocket->is_available()) {
79       std::optional<fn::MFDataType> data_type = socket_mf_type_get(*dsocket->bsocket()->typeinfo);
80       if (data_type.has_value()) {
81         output_types.append(*data_type);
82         output_names.append(dsocket->name());
83         output_dsockets.append(dsocket);
84       }
85     }
86   }
87
88   fn::MFDummyNode &dummy_node = common.network.add_dummy(
89       dnode.name(), input_types, output_types, input_names, output_names);
90
91   common.network_map.add(input_dsockets, dummy_node.inputs());
92   common.network_map.add(output_dsockets, dummy_node.outputs());
93 }
94
95 static bool has_data_sockets(const DNode &dnode)
96 {
97   for (const DInputSocket *socket : dnode.inputs()) {
98     if (socket_is_mf_data_socket(*socket->bsocket()->typeinfo)) {
99       return true;
100     }
101   }
102   for (const DOutputSocket *socket : dnode.outputs()) {
103     if (socket_is_mf_data_socket(*socket->bsocket()->typeinfo)) {
104       return true;
105     }
106   }
107   return false;
108 }
109
110 /**
111  * Expands all function nodes in the multi-function network. Nodes that don't have an expand
112  * function, but do have data sockets, will get corresponding dummy nodes.
113  */
114 static void insert_nodes(CommonMFNetworkBuilderData &common)
115 {
116   for (const DNode *dnode : common.tree.nodes()) {
117     const bNodeType *node_type = dnode->node_ref().bnode()->typeinfo;
118     if (node_type->expand_in_mf_network != nullptr) {
119       NodeMFNetworkBuilder builder{common, *dnode};
120       node_type->expand_in_mf_network(builder);
121     }
122     else if (has_data_sockets(*dnode)) {
123       insert_dummy_node(common, *dnode);
124     }
125   }
126 }
127
128 static void insert_group_inputs(CommonMFNetworkBuilderData &common)
129 {
130   for (const DGroupInput *group_input : common.tree.group_inputs()) {
131     bNodeSocket *bsocket = group_input->bsocket();
132     if (socket_is_mf_data_socket(*bsocket->typeinfo)) {
133       bNodeSocketType *socktype = bsocket->typeinfo;
134       BLI_assert(socktype->expand_in_mf_network != nullptr);
135
136       SocketMFNetworkBuilder builder{common, *group_input};
137       socktype->expand_in_mf_network(builder);
138
139       fn::MFOutputSocket *from_socket = builder.built_socket();
140       BLI_assert(from_socket != nullptr);
141       common.network_map.add(*group_input, *from_socket);
142     }
143   }
144 }
145
146 static fn::MFOutputSocket *try_find_origin(CommonMFNetworkBuilderData &common,
147                                            const DInputSocket &to_dsocket)
148 {
149   Span<const DOutputSocket *> from_dsockets = to_dsocket.linked_sockets();
150   Span<const DGroupInput *> from_group_inputs = to_dsocket.linked_group_inputs();
151   int total_linked_amount = from_dsockets.size() + from_group_inputs.size();
152   BLI_assert(total_linked_amount <= 1);
153
154   if (total_linked_amount == 0) {
155     return nullptr;
156   }
157
158   if (from_dsockets.size() == 1) {
159     const DOutputSocket &from_dsocket = *from_dsockets[0];
160     if (!from_dsocket.is_available()) {
161       return nullptr;
162     }
163     if (socket_is_mf_data_socket(*from_dsocket.bsocket()->typeinfo)) {
164       return &common.network_map.lookup(from_dsocket);
165     }
166     return nullptr;
167   }
168
169   const DGroupInput &from_group_input = *from_group_inputs[0];
170   if (socket_is_mf_data_socket(*from_group_input.bsocket()->typeinfo)) {
171     return &common.network_map.lookup(from_group_input);
172   }
173   return nullptr;
174 }
175
176 template<typename From, typename To>
177 static void add_implicit_conversion(DataTypeConversions &conversions)
178 {
179   static fn::CustomMF_Convert<From, To> function;
180   conversions.add(fn::MFDataType::ForSingle<From>(), fn::MFDataType::ForSingle<To>(), function);
181 }
182
183 template<typename From, typename To, typename ConversionF>
184 static void add_implicit_conversion(DataTypeConversions &conversions,
185                                     StringRef name,
186                                     ConversionF conversion)
187 {
188   static fn::CustomMF_SI_SO<From, To> function{name, conversion};
189   conversions.add(fn::MFDataType::ForSingle<From>(), fn::MFDataType::ForSingle<To>(), function);
190 }
191
192 static DataTypeConversions create_implicit_conversions()
193 {
194   DataTypeConversions conversions;
195   add_implicit_conversion<float, float2>(conversions);
196   add_implicit_conversion<float, float3>(conversions);
197   add_implicit_conversion<float, int32_t>(conversions);
198   add_implicit_conversion<float, bool>(conversions);
199   add_implicit_conversion<float, Color4f>(
200       conversions, "float to Color4f", [](float a) { return Color4f(a, a, a, 1.0f); });
201
202   add_implicit_conversion<float2, float3>(
203       conversions, "float2 to float3", [](float2 a) { return float3(a.x, a.y, 0.0f); });
204   add_implicit_conversion<float2, float>(
205       conversions, "float2 to float", [](float2 a) { return a.length(); });
206   add_implicit_conversion<float2, int32_t>(
207       conversions, "float2 to int32_t", [](float2 a) { return (int32_t)a.length(); });
208   add_implicit_conversion<float2, bool>(
209       conversions, "float2 to bool", [](float2 a) { return a.length_squared() == 0.0f; });
210   add_implicit_conversion<float2, Color4f>(
211       conversions, "float2 to Color4f", [](float2 a) { return Color4f(a.x, a.y, 0.0f, 1.0f); });
212
213   add_implicit_conversion<float3, bool>(
214       conversions, "float3 to boolean", [](float3 a) { return a.length_squared() == 0.0f; });
215   add_implicit_conversion<float3, float>(
216       conversions, "Vector Length", [](float3 a) { return a.length(); });
217   add_implicit_conversion<float3, int32_t>(
218       conversions, "float3 to int32_t", [](float3 a) { return (int)a.length(); });
219   add_implicit_conversion<float3, float2>(conversions);
220   add_implicit_conversion<float3, Color4f>(
221       conversions, "float3 to Color4f", [](float3 a) { return Color4f(a.x, a.y, a.z, 1.0f); });
222
223   add_implicit_conversion<int32_t, bool>(conversions);
224   add_implicit_conversion<int32_t, float>(conversions);
225   add_implicit_conversion<int32_t, float2>(
226       conversions, "int32 to float2", [](int32_t a) { return float2((float)a); });
227   add_implicit_conversion<int32_t, float3>(
228       conversions, "int32 to float3", [](int32_t a) { return float3((float)a); });
229
230   add_implicit_conversion<bool, float>(conversions);
231   add_implicit_conversion<bool, int32_t>(conversions);
232   add_implicit_conversion<bool, float2>(
233       conversions, "boolean to float2", [](bool a) { return (a) ? float2(1.0f) : float2(0.0f); });
234   add_implicit_conversion<bool, float3>(
235       conversions, "boolean to float3", [](bool a) { return (a) ? float3(1.0f) : float3(0.0f); });
236   add_implicit_conversion<bool, Color4f>(conversions, "boolean to Color4f", [](bool a) {
237     return (a) ? Color4f(1.0f, 1.0f, 1.0f, 1.0f) : Color4f(0.0f, 0.0f, 0.0f, 1.0f);
238   });
239
240   add_implicit_conversion<Color4f, float>(
241       conversions, "Color4f to float", [](Color4f a) { return rgb_to_grayscale(a); });
242   add_implicit_conversion<Color4f, float2>(
243       conversions, "Color4f to float2", [](Color4f a) { return float2(a.r, a.g); });
244   add_implicit_conversion<Color4f, float3>(
245       conversions, "Color4f to float3", [](Color4f a) { return float3(a.r, a.g, a.b); });
246
247   return conversions;
248 }
249
250 const DataTypeConversions &get_implicit_type_conversions()
251 {
252   static const DataTypeConversions conversions = create_implicit_conversions();
253   return conversions;
254 }
255
256 void DataTypeConversions::convert(const CPPType &from_type,
257                                   const CPPType &to_type,
258                                   const void *from_value,
259                                   void *to_value) const
260 {
261   const fn::MultiFunction *fn = this->get_conversion(MFDataType::ForSingle(from_type),
262                                                      MFDataType::ForSingle(to_type));
263   BLI_assert(fn != nullptr);
264
265   fn::MFContextBuilder context;
266   fn::MFParamsBuilder params{*fn, 1};
267   params.add_readonly_single_input(fn::GSpan(from_type, from_value, 1));
268   params.add_uninitialized_single_output(fn::GMutableSpan(to_type, to_value, 1));
269   fn->call({0}, params, context);
270 }
271
272 static fn::MFOutputSocket &insert_default_value_for_type(CommonMFNetworkBuilderData &common,
273                                                          fn::MFDataType type)
274 {
275   const fn::MultiFunction *default_fn;
276   if (type.is_single()) {
277     default_fn = &common.resources.construct<fn::CustomMF_GenericConstant>(
278         AT, type.single_type(), type.single_type().default_value());
279   }
280   else {
281     default_fn = &common.resources.construct<fn::CustomMF_GenericConstantArray>(
282         AT, fn::GSpan(type.vector_base_type()));
283   }
284
285   fn::MFNode &node = common.network.add_function(*default_fn);
286   return node.output(0);
287 }
288
289 static void insert_links(CommonMFNetworkBuilderData &common)
290 {
291   for (const DInputSocket *to_dsocket : common.tree.input_sockets()) {
292     if (!to_dsocket->is_available()) {
293       continue;
294     }
295     if (!to_dsocket->is_linked()) {
296       continue;
297     }
298     if (!socket_is_mf_data_socket(*to_dsocket->bsocket()->typeinfo)) {
299       continue;
300     }
301
302     Span<fn::MFInputSocket *> to_sockets = common.network_map.lookup(*to_dsocket);
303     BLI_assert(to_sockets.size() >= 1);
304     fn::MFDataType to_type = to_sockets[0]->data_type();
305
306     fn::MFOutputSocket *from_socket = try_find_origin(common, *to_dsocket);
307     if (from_socket == nullptr) {
308       from_socket = &insert_default_value_for_type(common, to_type);
309     }
310
311     fn::MFDataType from_type = from_socket->data_type();
312
313     if (from_type != to_type) {
314       const fn::MultiFunction *conversion_fn = get_implicit_type_conversions().get_conversion(
315           from_type, to_type);
316       if (conversion_fn != nullptr) {
317         fn::MFNode &node = common.network.add_function(*conversion_fn);
318         common.network.add_link(*from_socket, node.input(0));
319         from_socket = &node.output(0);
320       }
321       else {
322         from_socket = &insert_default_value_for_type(common, to_type);
323       }
324     }
325
326     for (fn::MFInputSocket *to_socket : to_sockets) {
327       common.network.add_link(*from_socket, *to_socket);
328     }
329   }
330 }
331
332 static void insert_unlinked_input(CommonMFNetworkBuilderData &common, const DInputSocket &dsocket)
333 {
334   bNodeSocket *bsocket = dsocket.bsocket();
335   bNodeSocketType *socktype = bsocket->typeinfo;
336   BLI_assert(socktype->expand_in_mf_network != nullptr);
337
338   SocketMFNetworkBuilder builder{common, dsocket};
339   socktype->expand_in_mf_network(builder);
340
341   fn::MFOutputSocket *from_socket = builder.built_socket();
342   BLI_assert(from_socket != nullptr);
343
344   for (fn::MFInputSocket *to_socket : common.network_map.lookup(dsocket)) {
345     common.network.add_link(*from_socket, *to_socket);
346   }
347 }
348
349 static void insert_unlinked_inputs(CommonMFNetworkBuilderData &common)
350 {
351   Vector<const DInputSocket *> unlinked_data_inputs;
352   for (const DInputSocket *dsocket : common.tree.input_sockets()) {
353     if (dsocket->is_available()) {
354       if (socket_is_mf_data_socket(*dsocket->bsocket()->typeinfo)) {
355         if (!dsocket->is_linked()) {
356           insert_unlinked_input(common, *dsocket);
357         }
358       }
359     }
360   }
361 }
362
363 /**
364  * Expands all function nodes contained in the given node tree within the given multi-function
365  * network.
366  *
367  * Returns a mapping between the original node tree and the generated nodes/sockets for further
368  * processing.
369  */
370 MFNetworkTreeMap insert_node_tree_into_mf_network(fn::MFNetwork &network,
371                                                   const DerivedNodeTree &tree,
372                                                   ResourceCollector &resources)
373 {
374   MFNetworkTreeMap network_map{tree, network};
375
376   CommonMFNetworkBuilderData common{resources, network, network_map, tree};
377
378   insert_nodes(common);
379   insert_group_inputs(common);
380   insert_links(common);
381   insert_unlinked_inputs(common);
382
383   return network_map;
384 }
385
386 /**
387  * A single node is allowed to expand into multiple nodes before evaluation. Depending on what
388  * nodes it expands to, it belongs a different type of the ones below.
389  */
390 enum class NodeExpandType {
391   SingleFunctionNode,
392   MultipleFunctionNodes,
393   HasDummyNodes,
394 };
395
396 /**
397  * Checks how the given node expanded in the multi-function network. If it is only a single
398  * function node, the corresponding function is returned as well.
399  */
400 static NodeExpandType get_node_expand_type(MFNetworkTreeMap &network_map,
401                                            const DNode &dnode,
402                                            const fn::MultiFunction **r_single_function)
403 {
404   const fn::MFFunctionNode *single_function_node = nullptr;
405   bool has_multiple_nodes = false;
406   bool has_dummy_nodes = false;
407
408   auto check_mf_node = [&](fn::MFNode &mf_node) {
409     if (mf_node.is_function()) {
410       if (single_function_node == nullptr) {
411         single_function_node = &mf_node.as_function();
412       }
413       if (&mf_node != single_function_node) {
414         has_multiple_nodes = true;
415       }
416     }
417     else {
418       BLI_assert(mf_node.is_dummy());
419       has_dummy_nodes = true;
420     }
421   };
422
423   for (const DInputSocket *dsocket : dnode.inputs()) {
424     if (dsocket->is_available()) {
425       for (fn::MFInputSocket *mf_input : network_map.lookup(*dsocket)) {
426         check_mf_node(mf_input->node());
427       }
428     }
429   }
430   for (const DOutputSocket *dsocket : dnode.outputs()) {
431     if (dsocket->is_available()) {
432       fn::MFOutputSocket &mf_output = network_map.lookup(*dsocket);
433       check_mf_node(mf_output.node());
434     }
435   }
436
437   if (has_dummy_nodes) {
438     return NodeExpandType::HasDummyNodes;
439   }
440   if (has_multiple_nodes) {
441     return NodeExpandType::MultipleFunctionNodes;
442   }
443   *r_single_function = &single_function_node->function();
444   return NodeExpandType::SingleFunctionNode;
445 }
446
447 static const fn::MultiFunction &create_function_for_node_that_expands_into_multiple(
448     const DNode &dnode,
449     fn::MFNetwork &network,
450     MFNetworkTreeMap &network_map,
451     ResourceCollector &resources)
452 {
453   Vector<const fn::MFOutputSocket *> dummy_fn_inputs;
454   for (const DInputSocket *dsocket : dnode.inputs()) {
455     if (dsocket->is_available()) {
456       MFDataType data_type = *socket_mf_type_get(*dsocket->typeinfo());
457       fn::MFOutputSocket &fn_input = network.add_input(data_type.to_string(), data_type);
458       for (fn::MFInputSocket *mf_input : network_map.lookup(*dsocket)) {
459         network.add_link(fn_input, *mf_input);
460         dummy_fn_inputs.append(&fn_input);
461       }
462     }
463   }
464   Vector<const fn::MFInputSocket *> dummy_fn_outputs;
465   for (const DOutputSocket *dsocket : dnode.outputs()) {
466     if (dsocket->is_available()) {
467       fn::MFOutputSocket &mf_output = network_map.lookup(*dsocket);
468       MFDataType data_type = mf_output.data_type();
469       fn::MFInputSocket &fn_output = network.add_output(data_type.to_string(), data_type);
470       network.add_link(mf_output, fn_output);
471       dummy_fn_outputs.append(&fn_output);
472     }
473   }
474
475   fn::MFNetworkEvaluator &fn_evaluator = resources.construct<fn::MFNetworkEvaluator>(
476       __func__, std::move(dummy_fn_inputs), std::move(dummy_fn_outputs));
477   return fn_evaluator;
478 }
479
480 /**
481  * Returns a single multi-function for every node that supports it. This makes it easier to reuse
482  * the multi-function implementation of nodes in different contexts.
483  */
484 MultiFunctionByNode get_multi_function_per_node(const DerivedNodeTree &tree,
485                                                 ResourceCollector &resources)
486 {
487   /* Build a network that nodes can insert themselves into. However, the individual nodes are not
488    * connected. */
489   fn::MFNetwork &network = resources.construct<fn::MFNetwork>(__func__);
490   MFNetworkTreeMap network_map{tree, network};
491   MultiFunctionByNode functions_by_node;
492
493   CommonMFNetworkBuilderData common{resources, network, network_map, tree};
494
495   for (const DNode *dnode : tree.nodes()) {
496     const bNodeType *node_type = dnode->typeinfo();
497     if (node_type->expand_in_mf_network == nullptr) {
498       /* This node does not have a multi-function implementation. */
499       continue;
500     }
501
502     NodeMFNetworkBuilder builder{common, *dnode};
503     node_type->expand_in_mf_network(builder);
504
505     const fn::MultiFunction *single_function = nullptr;
506     const NodeExpandType expand_type = get_node_expand_type(network_map, *dnode, &single_function);
507
508     switch (expand_type) {
509       case NodeExpandType::HasDummyNodes: {
510         /* Dummy nodes cannot be executed, so skip them. */
511         break;
512       }
513       case NodeExpandType::SingleFunctionNode: {
514         /* This is the common case. Most nodes just expand to a single function. */
515         functions_by_node.add_new(dnode, single_function);
516         break;
517       }
518       case NodeExpandType::MultipleFunctionNodes: {
519         /* If a node expanded into multiple functions, a new function has to be created that
520          * combines those. */
521         const fn::MultiFunction &fn = create_function_for_node_that_expands_into_multiple(
522             *dnode, network, network_map, resources);
523         functions_by_node.add_new(dnode, &fn);
524         break;
525       }
526     }
527   }
528
529   return functions_by_node;
530 }
531
532 }  // namespace blender::nodes