Cycles: Make all #include statements relative to cycles source directory
[blender.git] / intern / cycles / device / device_network.h
1 /*
2  * Copyright 2011-2013 Blender Foundation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __DEVICE_NETWORK_H__
18 #define __DEVICE_NETWORK_H__
19
20 #ifdef WITH_NETWORK
21
#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
#include <boost/array.hpp>
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <boost/serialization/vector.hpp>
#include <boost/thread.hpp>

#include <algorithm>
#include <deque>
#include <iomanip>
#include <iostream>
#include <sstream>

#include "render/buffers.h"

#include "util/util_foreach.h"
#include "util/util_list.h"
#include "util/util_map.h"
#include "util/util_string.h"
42
43 CCL_NAMESPACE_BEGIN
44
45 using std::cout;
46 using std::cerr;
47 using std::hex;
48 using std::setw;
49 using std::exception;
50
51 using boost::asio::ip::tcp;
52
53 static const int SERVER_PORT = 5120;
54 static const int DISCOVER_PORT = 5121;
55 static const string DISCOVER_REQUEST_MSG = "REQUEST_RENDER_SERVER_IP";
56 static const string DISCOVER_REPLY_MSG = "REPLY_RENDER_SERVER_IP";
57
58 #if 0
59 typedef boost::archive::text_oarchive o_archive;
60 typedef boost::archive::text_iarchive i_archive;
61 #else
62 typedef boost::archive::binary_oarchive o_archive;
63 typedef boost::archive::binary_iarchive i_archive;
64 #endif
65
66 /* Serialization of device memory */
67
68 class network_device_memory : public device_memory
69 {
70 public:
71         network_device_memory() {}
72         ~network_device_memory() { device_pointer = 0; };
73
74         vector<char> local_data;
75 };
76
/* Common network error function / object for both DeviceNetwork and DeviceServer.
 * Remembers the most recently reported message and how many errors occurred. */
class NetworkError {
public:
	NetworkError()
	: error(""), error_count(0)
	{
	}

	~NetworkError() {}

	/* Record an error: stores the message and increments the error count. */
	void network_error(const string& message)
	{
		error = message;
		error_count += 1;
	}

	/* True once at least one error has been reported. */
	bool have_error() const
	{
		return error_count > 0;
	}

	/* Message of the most recently reported error, "" if none was reported. */
	const string& get_error() const
	{
		return error;
	}

private:
	string error;
	int error_count;
};
100
101
102 /* Remote procedure call Send */
103
104 class RPCSend {
105 public:
106         RPCSend(tcp::socket& socket_, NetworkError* e, const string& name_ = "")
107         : name(name_), socket(socket_), archive(archive_stream), sent(false)
108         {
109                 archive & name_;
110                 error_func = e;
111                 fprintf(stderr, "rpc send %s\n", name.c_str());
112         }
113
114         ~RPCSend()
115         {
116         }
117
118         void add(const device_memory& mem)
119         {
120                 archive & mem.data_type & mem.data_elements & mem.data_size;
121                 archive & mem.data_width & mem.data_height & mem.data_depth & mem.device_pointer;
122         }
123
124         template<typename T> void add(const T& data)
125         {
126                 archive & data;
127         }
128
129         void add(const DeviceTask& task)
130         {
131                 int type = (int)task.type;
132                 archive & type & task.x & task.y & task.w & task.h;
133                 archive & task.rgba_byte & task.rgba_half & task.buffer & task.sample & task.num_samples;
134                 archive & task.offset & task.stride;
135                 archive & task.shader_input & task.shader_output & task.shader_output_luma & task.shader_eval_type;
136                 archive & task.shader_x & task.shader_w;
137                 archive & task.need_finish_queue;
138         }
139
140         void add(const RenderTile& tile)
141         {
142                 archive & tile.x & tile.y & tile.w & tile.h;
143                 archive & tile.start_sample & tile.num_samples & tile.sample;
144                 archive & tile.resolution & tile.offset & tile.stride;
145                 archive & tile.buffer & tile.rng_state;
146         }
147
148         void write()
149         {
150                 boost::system::error_code error;
151
152                 /* get string from stream */
153                 string archive_str = archive_stream.str();
154
155                 /* first send fixed size header with size of following data */
156                 ostringstream header_stream;
157                 header_stream << setw(8) << hex << archive_str.size();
158                 string header_str = header_stream.str();
159
160                 boost::asio::write(socket,
161                         boost::asio::buffer(header_str),
162                         boost::asio::transfer_all(), error);
163
164                 if(error.value())
165                         error_func->network_error(error.message());
166
167                 /* then send actual data */
168                 boost::asio::write(socket,
169                         boost::asio::buffer(archive_str),
170                         boost::asio::transfer_all(), error);
171                 
172                 if(error.value())
173                         error_func->network_error(error.message());
174
175                 sent = true;
176         }
177
178         void write_buffer(void *buffer, size_t size)
179         {
180                 boost::system::error_code error;
181
182                 boost::asio::write(socket,
183                         boost::asio::buffer(buffer, size),
184                         boost::asio::transfer_all(), error);
185                 
186                 if(error.value())
187                         error_func->network_error(error.message());
188         }
189
190 protected:
191         string name;
192         tcp::socket& socket;
193         ostringstream archive_stream;
194         o_archive archive;
195         bool sent;
196         NetworkError *error_func;
197 };
198
199 /* Remote procedure call Receive */
200
201 class RPCReceive {
202 public:
203         RPCReceive(tcp::socket& socket_, NetworkError* e )
204         : socket(socket_), archive_stream(NULL), archive(NULL)
205         {
206                 error_func = e;
207                 /* read head with fixed size */
208                 vector<char> header(8);
209                 boost::system::error_code error;
210                 size_t len = boost::asio::read(socket, boost::asio::buffer(header), error);
211
212                 if(error.value()) {
213                         error_func->network_error(error.message());
214                 }
215
216                 /* verify if we got something */
217                 if(len == header.size()) {
218                         /* decode header */
219                         string header_str(&header[0], header.size());
220                         istringstream header_stream(header_str);
221
222                         size_t data_size;
223
224                         if((header_stream >> hex >> data_size)) {
225
226                                 vector<char> data(data_size);
227                                 size_t len = boost::asio::read(socket, boost::asio::buffer(data), error);
228
229                                 if(error.value())
230                                         error_func->network_error(error.message());
231
232
233                                 if(len == data_size) {
234                                         archive_str = (data.size())? string(&data[0], data.size()): string("");
235
236                                         archive_stream = new istringstream(archive_str);
237                                         archive = new i_archive(*archive_stream);
238
239                                         *archive & name;
240                                         fprintf(stderr, "rpc receive %s\n", name.c_str());
241                                 }
242                                 else {
243                                         error_func->network_error("Network receive error: data size doesn't match header");
244                                 }
245                         }
246                         else {
247                                 error_func->network_error("Network receive error: can't decode data size from header");
248                         }
249                 }
250                 else {
251                         error_func->network_error("Network receive error: invalid header size");
252                 }
253         }
254
255         ~RPCReceive()
256         {
257                 delete archive;
258                 delete archive_stream;
259         }
260
261         void read(network_device_memory& mem)
262         {
263                 *archive & mem.data_type & mem.data_elements & mem.data_size;
264                 *archive & mem.data_width & mem.data_height & mem.data_depth & mem.device_pointer;
265
266                 mem.data_pointer = 0;
267         }
268
269         template<typename T> void read(T& data)
270         {
271                 *archive & data;
272         }
273
274         void read_buffer(void *buffer, size_t size)
275         {
276                 boost::system::error_code error;
277                 size_t len = boost::asio::read(socket, boost::asio::buffer(buffer, size), error);
278
279                 if(error.value()) {
280                         error_func->network_error(error.message());
281                 }
282
283                 if(len != size)
284                         cout << "Network receive error: buffer size doesn't match expected size\n";
285         }
286
287         void read(DeviceTask& task)
288         {
289                 int type;
290
291                 *archive & type & task.x & task.y & task.w & task.h;
292                 *archive & task.rgba_byte & task.rgba_half & task.buffer & task.sample & task.num_samples;
293                 *archive & task.offset & task.stride;
294                 *archive & task.shader_input & task.shader_output & task.shader_output_luma & task.shader_eval_type;
295                 *archive & task.shader_x & task.shader_w;
296                 *archive & task.need_finish_queue;
297
298                 task.type = (DeviceTask::Type)type;
299         }
300
301         void read(RenderTile& tile)
302         {
303                 *archive & tile.x & tile.y & tile.w & tile.h;
304                 *archive & tile.start_sample & tile.num_samples & tile.sample;
305                 *archive & tile.resolution & tile.offset & tile.stride;
306                 *archive & tile.buffer & tile.rng_state;
307
308                 tile.buffers = NULL;
309         }
310
311         string name;
312
313 protected:
314         tcp::socket& socket;
315         string archive_str;
316         istringstream *archive_stream;
317         i_archive *archive;
318         NetworkError *error_func;
319 };
320
321 /* Server auto discovery */
322
323 class ServerDiscovery {
324 public:
325         explicit ServerDiscovery(bool discover = false)
326         : listen_socket(io_service), collect_servers(false)
327         {
328                 /* setup listen socket */
329                 listen_endpoint.address(boost::asio::ip::address_v4::any());
330                 listen_endpoint.port(DISCOVER_PORT);
331
332                 listen_socket.open(listen_endpoint.protocol());
333
334                 boost::asio::socket_base::reuse_address option(true);
335                 listen_socket.set_option(option);
336
337                 listen_socket.bind(listen_endpoint);
338
339                 /* setup receive callback */
340                 async_receive();
341
342                 /* start server discovery */
343                 if(discover) {
344                         collect_servers = true;
345                         servers.clear();
346
347                         broadcast_message(DISCOVER_REQUEST_MSG);
348                 }
349
350                 /* start thread */
351                 work = new boost::asio::io_service::work(io_service);
352                 thread = new boost::thread(boost::bind(&boost::asio::io_service::run, &io_service));
353         }
354
355         ~ServerDiscovery()
356         {
357                 io_service.stop();
358                 thread->join();
359                 delete thread;
360                 delete work;
361         }
362
363         vector<string> get_server_list()
364         {
365                 vector<string> result;
366
367                 mutex.lock();
368                 result = vector<string>(servers.begin(), servers.end());
369                 mutex.unlock();
370
371                 return result;
372         }
373
374 private:
375         void handle_receive_from(const boost::system::error_code& error, size_t size)
376         {
377                 if(error) {
378                         cout << "Server discovery receive error: " << error.message() << "\n";
379                         return;
380                 }
381
382                 if(size > 0) {
383                         string msg = string(receive_buffer, size);
384
385                         /* handle incoming message */
386                         if(collect_servers) {
387                                 if(msg == DISCOVER_REPLY_MSG) {
388                                         string address = receive_endpoint.address().to_string();
389
390                                         mutex.lock();
391
392                                         /* add address if it's not already in the list */
393                                         bool found = std::find(servers.begin(), servers.end(),
394                                                                address) != servers.end();
395
396                                         if(!found)
397                                                 servers.push_back(address);
398
399                                         mutex.unlock();
400                                 }
401                         }
402                         else {
403                                 /* reply to request */
404                                 if(msg == DISCOVER_REQUEST_MSG)
405                                         broadcast_message(DISCOVER_REPLY_MSG);
406                         }
407                 }
408
409                 async_receive();
410         }
411
412         void async_receive()
413         {
414                 listen_socket.async_receive_from(
415                         boost::asio::buffer(receive_buffer), receive_endpoint,
416                         boost::bind(&ServerDiscovery::handle_receive_from, this,
417                         boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred));
418         }
419
420         void broadcast_message(const string& msg)
421         {
422                 /* setup broadcast socket */
423                 boost::asio::ip::udp::socket socket(io_service);
424
425                 socket.open(boost::asio::ip::udp::v4());
426
427                 boost::asio::socket_base::broadcast option(true);
428                 socket.set_option(option);
429
430                 boost::asio::ip::udp::endpoint broadcast_endpoint(
431                         boost::asio::ip::address::from_string("255.255.255.255"), DISCOVER_PORT);
432
433                 /* broadcast message */
434                 socket.send_to(boost::asio::buffer(msg), broadcast_endpoint);
435         }
436
437         /* network service and socket */
438         boost::asio::io_service io_service;
439         boost::asio::ip::udp::endpoint listen_endpoint;
440         boost::asio::ip::udp::socket listen_socket;
441
442         /* threading */
443         boost::thread *thread;
444         boost::asio::io_service::work *work;
445         boost::mutex mutex;
446
447         /* buffer and endpoint for receiving messages */
448         char receive_buffer[256];
449         boost::asio::ip::udp::endpoint receive_endpoint;
450         
451         // os, version, devices, status, host name, group name, ip as far as fields go
452         struct ServerInfo {
453                 string cycles_version;
454                 string os;
455                 int device_count;
456                 string status;
457                 string host_name;
458                 string group_name;
459                 string host_addr;
460         };
461
462         /* collection of server addresses in list */
463         bool collect_servers;
464         vector<string> servers;
465 };
466
467 CCL_NAMESPACE_END
468
469 #endif
470
471 #endif /* __DEVICE_NETWORK_H__ */
472