Code refactor: move more memory allocation logic into device API.
[blender-staging.git] / intern / cycles / device / device_network.h
1 /*
2  * Copyright 2011-2013 Blender Foundation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __DEVICE_NETWORK_H__
18 #define __DEVICE_NETWORK_H__
19
20 #ifdef WITH_NETWORK
21
22 #include <boost/archive/text_iarchive.hpp>
23 #include <boost/archive/text_oarchive.hpp>
24 #include <boost/archive/binary_iarchive.hpp>
25 #include <boost/archive/binary_oarchive.hpp>
26 #include <boost/array.hpp>
27 #include <boost/asio.hpp>
28 #include <boost/bind.hpp>
29 #include <boost/serialization/vector.hpp>
30 #include <boost/thread.hpp>
31
32 #include <iostream>
33 #include <sstream>
34 #include <deque>
35
36 #include "render/buffers.h"
37
38 #include "util/util_foreach.h"
39 #include "util/util_list.h"
40 #include "util/util_map.h"
41 #include "util/util_param.h"
42 #include "util/util_string.h"
43
44 CCL_NAMESPACE_BEGIN
45
46 using std::cout;
47 using std::cerr;
48 using std::hex;
49 using std::setw;
50 using std::exception;
51
52 using boost::asio::ip::tcp;
53
/* Port for the RPC (TCP) connection. Not referenced in this header; presumably
 * used by the network device / server implementations — verify there. */
static const int SERVER_PORT = 5120;
/* UDP port that ServerDiscovery both listens on and broadcasts to. */
static const int DISCOVER_PORT = 5121;
/* Broadcast payload sent when searching for render servers. */
static const string DISCOVER_REQUEST_MSG = "REQUEST_RENDER_SERVER_IP";
/* Broadcast payload sent in answer to DISCOVER_REQUEST_MSG. */
static const string DISCOVER_REPLY_MSG = "REPLY_RENDER_SERVER_IP";

/* Boost archive types used by RPCSend (o_archive) and RPCReceive (i_archive)
 * to serialize RPC payloads. Sender and receiver must be built with the same
 * choice. Changing the condition to 1 selects the text archives instead —
 * presumably kept around for inspecting the wire format while debugging. */
#if 0
typedef boost::archive::text_oarchive o_archive;
typedef boost::archive::text_iarchive i_archive;
#else
typedef boost::archive::binary_oarchive o_archive;
typedef boost::archive::binary_iarchive i_archive;
#endif
66
67 /* Serialization of device memory */
68
69 class network_device_memory : public device_memory
70 {
71 public:
72         network_device_memory(Device *device)
73         : device_memory(device, "", MEM_READ_ONLY)
74         {
75         }
76
77         ~network_device_memory()
78         {
79                 device_pointer = 0;
80         };
81
82         vector<char> local_data;
83 };
84
/* Common network error function / object for both DeviceNetwork and DeviceServer.
 *
 * Collects errors reported by RPCSend / RPCReceive so callers can check
 * have_error() after a sequence of network operations. Only the most recent
 * message is kept; error_count counts every report. */
class NetworkError {
public:
	NetworkError()
	: error(""), error_count(0)
	{
	}

	~NetworkError() {}

	/* Record an error. The message replaces any previously recorded one. */
	void network_error(const std::string& message) {
		error = message;
		error_count += 1;
	}

	/* True once at least one error has been reported. */
	bool have_error() const {
		return error_count > 0;
	}

	/* Most recently reported message, or "" if no error has been reported. */
	const std::string& error_message() const {
		return error;
	}

	/* Number of errors reported so far. */
	int num_errors() const {
		return error_count;
	}

private:
	std::string error;
	int error_count;
};
108
109
110 /* Remote procedure call Send */
111
112 class RPCSend {
113 public:
114         RPCSend(tcp::socket& socket_, NetworkError* e, const string& name_ = "")
115         : name(name_), socket(socket_), archive(archive_stream), sent(false)
116         {
117                 archive & name_;
118                 error_func = e;
119                 fprintf(stderr, "rpc send %s\n", name.c_str());
120         }
121
122         ~RPCSend()
123         {
124         }
125
126         void add(const device_memory& mem)
127         {
128                 archive & mem.data_type & mem.data_elements & mem.data_size;
129                 archive & mem.data_width & mem.data_height & mem.data_depth & mem.device_pointer;
130                 archive & mem.type & string(mem.name);
131                 archive & mem.interpolation & mem.extension;
132                 archive & mem.device_pointer;
133         }
134
135         template<typename T> void add(const T& data)
136         {
137                 archive & data;
138         }
139
140         void add(const DeviceTask& task)
141         {
142                 int type = (int)task.type;
143                 archive & type & task.x & task.y & task.w & task.h;
144                 archive & task.rgba_byte & task.rgba_half & task.buffer & task.sample & task.num_samples;
145                 archive & task.offset & task.stride;
146                 archive & task.shader_input & task.shader_output & task.shader_eval_type;
147                 archive & task.shader_x & task.shader_w;
148                 archive & task.need_finish_queue;
149         }
150
151         void add(const RenderTile& tile)
152         {
153                 archive & tile.x & tile.y & tile.w & tile.h;
154                 archive & tile.start_sample & tile.num_samples & tile.sample;
155                 archive & tile.resolution & tile.offset & tile.stride;
156                 archive & tile.buffer;
157         }
158
159         void write()
160         {
161                 boost::system::error_code error;
162
163                 /* get string from stream */
164                 string archive_str = archive_stream.str();
165
166                 /* first send fixed size header with size of following data */
167                 ostringstream header_stream;
168                 header_stream << setw(8) << hex << archive_str.size();
169                 string header_str = header_stream.str();
170
171                 boost::asio::write(socket,
172                         boost::asio::buffer(header_str),
173                         boost::asio::transfer_all(), error);
174
175                 if(error.value())
176                         error_func->network_error(error.message());
177
178                 /* then send actual data */
179                 boost::asio::write(socket,
180                         boost::asio::buffer(archive_str),
181                         boost::asio::transfer_all(), error);
182                 
183                 if(error.value())
184                         error_func->network_error(error.message());
185
186                 sent = true;
187         }
188
189         void write_buffer(void *buffer, size_t size)
190         {
191                 boost::system::error_code error;
192
193                 boost::asio::write(socket,
194                         boost::asio::buffer(buffer, size),
195                         boost::asio::transfer_all(), error);
196                 
197                 if(error.value())
198                         error_func->network_error(error.message());
199         }
200
201 protected:
202         string name;
203         tcp::socket& socket;
204         ostringstream archive_stream;
205         o_archive archive;
206         bool sent;
207         NetworkError *error_func;
208 };
209
210 /* Remote procedure call Receive */
211
212 class RPCReceive {
213 public:
214         RPCReceive(tcp::socket& socket_, NetworkError* e )
215         : socket(socket_), archive_stream(NULL), archive(NULL)
216         {
217                 error_func = e;
218                 /* read head with fixed size */
219                 vector<char> header(8);
220                 boost::system::error_code error;
221                 size_t len = boost::asio::read(socket, boost::asio::buffer(header), error);
222
223                 if(error.value()) {
224                         error_func->network_error(error.message());
225                 }
226
227                 /* verify if we got something */
228                 if(len == header.size()) {
229                         /* decode header */
230                         string header_str(&header[0], header.size());
231                         istringstream header_stream(header_str);
232
233                         size_t data_size;
234
235                         if((header_stream >> hex >> data_size)) {
236
237                                 vector<char> data(data_size);
238                                 size_t len = boost::asio::read(socket, boost::asio::buffer(data), error);
239
240                                 if(error.value())
241                                         error_func->network_error(error.message());
242
243
244                                 if(len == data_size) {
245                                         archive_str = (data.size())? string(&data[0], data.size()): string("");
246
247                                         archive_stream = new istringstream(archive_str);
248                                         archive = new i_archive(*archive_stream);
249
250                                         *archive & name;
251                                         fprintf(stderr, "rpc receive %s\n", name.c_str());
252                                 }
253                                 else {
254                                         error_func->network_error("Network receive error: data size doesn't match header");
255                                 }
256                         }
257                         else {
258                                 error_func->network_error("Network receive error: can't decode data size from header");
259                         }
260                 }
261                 else {
262                         error_func->network_error("Network receive error: invalid header size");
263                 }
264         }
265
266         ~RPCReceive()
267         {
268                 delete archive;
269                 delete archive_stream;
270         }
271
272         void read(network_device_memory& mem, string& name)
273         {
274                 *archive & mem.data_type & mem.data_elements & mem.data_size;
275                 *archive & mem.data_width & mem.data_height & mem.data_depth & mem.device_pointer;
276                 *archive & mem.type & name;
277                 *archive & mem.interpolation & mem.extension;
278                 *archive & mem.device_pointer;
279
280                 mem.name = name.c_str();
281                 mem.data_pointer = 0;
282
283                 /* Can't transfer OpenGL texture over network. */
284                 if(mem.type == MEM_PIXELS) {
285                         mem.type = MEM_WRITE_ONLY;
286                 }
287         }
288
289         template<typename T> void read(T& data)
290         {
291                 *archive & data;
292         }
293
294         void read_buffer(void *buffer, size_t size)
295         {
296                 boost::system::error_code error;
297                 size_t len = boost::asio::read(socket, boost::asio::buffer(buffer, size), error);
298
299                 if(error.value()) {
300                         error_func->network_error(error.message());
301                 }
302
303                 if(len != size)
304                         cout << "Network receive error: buffer size doesn't match expected size\n";
305         }
306
307         void read(DeviceTask& task)
308         {
309                 int type;
310
311                 *archive & type & task.x & task.y & task.w & task.h;
312                 *archive & task.rgba_byte & task.rgba_half & task.buffer & task.sample & task.num_samples;
313                 *archive & task.offset & task.stride;
314                 *archive & task.shader_input & task.shader_output & task.shader_eval_type;
315                 *archive & task.shader_x & task.shader_w;
316                 *archive & task.need_finish_queue;
317
318                 task.type = (DeviceTask::Type)type;
319         }
320
321         void read(RenderTile& tile)
322         {
323                 *archive & tile.x & tile.y & tile.w & tile.h;
324                 *archive & tile.start_sample & tile.num_samples & tile.sample;
325                 *archive & tile.resolution & tile.offset & tile.stride;
326                 *archive & tile.buffer;
327
328                 tile.buffers = NULL;
329         }
330
331         string name;
332
333 protected:
334         tcp::socket& socket;
335         string archive_str;
336         istringstream *archive_stream;
337         i_archive *archive;
338         NetworkError *error_func;
339 };
340
341 /* Server auto discovery */
342
343 class ServerDiscovery {
344 public:
345         explicit ServerDiscovery(bool discover = false)
346         : listen_socket(io_service), collect_servers(false)
347         {
348                 /* setup listen socket */
349                 listen_endpoint.address(boost::asio::ip::address_v4::any());
350                 listen_endpoint.port(DISCOVER_PORT);
351
352                 listen_socket.open(listen_endpoint.protocol());
353
354                 boost::asio::socket_base::reuse_address option(true);
355                 listen_socket.set_option(option);
356
357                 listen_socket.bind(listen_endpoint);
358
359                 /* setup receive callback */
360                 async_receive();
361
362                 /* start server discovery */
363                 if(discover) {
364                         collect_servers = true;
365                         servers.clear();
366
367                         broadcast_message(DISCOVER_REQUEST_MSG);
368                 }
369
370                 /* start thread */
371                 work = new boost::asio::io_service::work(io_service);
372                 thread = new boost::thread(boost::bind(&boost::asio::io_service::run, &io_service));
373         }
374
375         ~ServerDiscovery()
376         {
377                 io_service.stop();
378                 thread->join();
379                 delete thread;
380                 delete work;
381         }
382
383         vector<string> get_server_list()
384         {
385                 vector<string> result;
386
387                 mutex.lock();
388                 result = vector<string>(servers.begin(), servers.end());
389                 mutex.unlock();
390
391                 return result;
392         }
393
394 private:
395         void handle_receive_from(const boost::system::error_code& error, size_t size)
396         {
397                 if(error) {
398                         cout << "Server discovery receive error: " << error.message() << "\n";
399                         return;
400                 }
401
402                 if(size > 0) {
403                         string msg = string(receive_buffer, size);
404
405                         /* handle incoming message */
406                         if(collect_servers) {
407                                 if(msg == DISCOVER_REPLY_MSG) {
408                                         string address = receive_endpoint.address().to_string();
409
410                                         mutex.lock();
411
412                                         /* add address if it's not already in the list */
413                                         bool found = std::find(servers.begin(), servers.end(),
414                                                                address) != servers.end();
415
416                                         if(!found)
417                                                 servers.push_back(address);
418
419                                         mutex.unlock();
420                                 }
421                         }
422                         else {
423                                 /* reply to request */
424                                 if(msg == DISCOVER_REQUEST_MSG)
425                                         broadcast_message(DISCOVER_REPLY_MSG);
426                         }
427                 }
428
429                 async_receive();
430         }
431
432         void async_receive()
433         {
434                 listen_socket.async_receive_from(
435                         boost::asio::buffer(receive_buffer), receive_endpoint,
436                         boost::bind(&ServerDiscovery::handle_receive_from, this,
437                         boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred));
438         }
439
440         void broadcast_message(const string& msg)
441         {
442                 /* setup broadcast socket */
443                 boost::asio::ip::udp::socket socket(io_service);
444
445                 socket.open(boost::asio::ip::udp::v4());
446
447                 boost::asio::socket_base::broadcast option(true);
448                 socket.set_option(option);
449
450                 boost::asio::ip::udp::endpoint broadcast_endpoint(
451                         boost::asio::ip::address::from_string("255.255.255.255"), DISCOVER_PORT);
452
453                 /* broadcast message */
454                 socket.send_to(boost::asio::buffer(msg), broadcast_endpoint);
455         }
456
457         /* network service and socket */
458         boost::asio::io_service io_service;
459         boost::asio::ip::udp::endpoint listen_endpoint;
460         boost::asio::ip::udp::socket listen_socket;
461
462         /* threading */
463         boost::thread *thread;
464         boost::asio::io_service::work *work;
465         boost::mutex mutex;
466
467         /* buffer and endpoint for receiving messages */
468         char receive_buffer[256];
469         boost::asio::ip::udp::endpoint receive_endpoint;
470         
471         // os, version, devices, status, host name, group name, ip as far as fields go
472         struct ServerInfo {
473                 string cycles_version;
474                 string os;
475                 int device_count;
476                 string status;
477                 string host_name;
478                 string group_name;
479                 string host_addr;
480         };
481
482         /* collection of server addresses in list */
483         bool collect_servers;
484         vector<string> servers;
485 };
486
487 CCL_NAMESPACE_END
488
489 #endif
490
491 #endif /* __DEVICE_NETWORK_H__ */
492