Merged changes in the trunk up to revision 54802.
[blender.git] / intern / cycles / device / device_network.cpp
1 /*
2  * Copyright 2011, Blender Foundation.
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU General Public License
6  * as published by the Free Software Foundation; either version 2
7  * of the License, or (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software Foundation,
16  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17  */
18
19 #include "device.h"
20 #include "device_intern.h"
21 #include "device_network.h"
22
23 #include "util_foreach.h"
24
25 CCL_NAMESPACE_BEGIN
26
27 #ifdef WITH_NETWORK
28
29 class NetworkDevice : public Device
30 {
31 public:
32         boost::asio::io_service io_service;
33         tcp::socket socket;
34         device_ptr mem_counter;
35         DeviceTask the_task; /* todo: handle multiple tasks */
36
37         NetworkDevice(Stats &stats, const char *address)
38         : Device(stats), socket(io_service)
39         {
40                 stringstream portstr;
41                 portstr << SERVER_PORT;
42
43                 tcp::resolver resolver(io_service);
44                 tcp::resolver::query query(address, portstr.str());
45                 tcp::resolver::iterator endpoint_iterator = resolver.resolve(query);
46                 tcp::resolver::iterator end;
47
48                 boost::system::error_code error = boost::asio::error::host_not_found;
49                 while(error && endpoint_iterator != end)
50                 {
51                         socket.close();
52                         socket.connect(*endpoint_iterator++, error);
53                 }
54
55                 if(error)
56                         throw boost::system::system_error(error);
57
58                 mem_counter = 0;
59         }
60
61         ~NetworkDevice()
62         {
63                 RPCSend snd(socket, "stop");
64                 snd.write();
65         }
66
67         void mem_alloc(device_memory& mem, MemoryType type)
68         {
69                 mem.device_pointer = ++mem_counter;
70
71                 RPCSend snd(socket, "mem_alloc");
72
73                 snd.add(mem);
74                 snd.add(type);
75                 snd.write();
76         }
77
78         void mem_copy_to(device_memory& mem)
79         {
80                 RPCSend snd(socket, "mem_copy_to");
81
82                 snd.add(mem);
83                 snd.write();
84                 snd.write_buffer((void*)mem.data_pointer, mem.memory_size());
85         }
86
87         void mem_copy_from(device_memory& mem, int y, int w, int h, int elem)
88         {
89                 RPCSend snd(socket, "mem_copy_from");
90
91                 snd.add(mem);
92                 snd.add(y);
93                 snd.add(w);
94                 snd.add(h);
95                 snd.add(elem);
96                 snd.write();
97
98                 RPCReceive rcv(socket);
99                 rcv.read_buffer((void*)mem.data_pointer, mem.memory_size());
100         }
101
102         void mem_zero(device_memory& mem)
103         {
104                 RPCSend snd(socket, "mem_zero");
105
106                 snd.add(mem);
107                 snd.write();
108         }
109
110         void mem_free(device_memory& mem)
111         {
112                 if(mem.device_pointer) {
113                         RPCSend snd(socket, "mem_free");
114
115                         snd.add(mem);
116                         snd.write();
117
118                         mem.device_pointer = 0;
119                 }
120         }
121
122         void const_copy_to(const char *name, void *host, size_t size)
123         {
124                 RPCSend snd(socket, "const_copy_to");
125
126                 string name_string(name);
127
128                 snd.add(name_string);
129                 snd.add(size);
130                 snd.write();
131                 snd.write_buffer(host, size);
132         }
133
134         void tex_alloc(const char *name, device_memory& mem, bool interpolation, bool periodic)
135         {
136                 mem.device_pointer = ++mem_counter;
137
138                 RPCSend snd(socket, "tex_alloc");
139
140                 string name_string(name);
141
142                 snd.add(name_string);
143                 snd.add(mem);
144                 snd.add(interpolation);
145                 snd.add(periodic);
146                 snd.write();
147                 snd.write_buffer((void*)mem.data_pointer, mem.memory_size());
148         }
149
150         void tex_free(device_memory& mem)
151         {
152                 if(mem.device_pointer) {
153                         RPCSend snd(socket, "tex_free");
154
155                         snd.add(mem);
156                         snd.write();
157
158                         mem.device_pointer = 0;
159                 }
160         }
161
162         void task_add(DeviceTask& task)
163         {
164                 the_task = task;
165
166                 RPCSend snd(socket, "task_add");
167                 snd.add(task);
168                 snd.write();
169         }
170
171         void task_wait()
172         {
173                 RPCSend snd(socket, "task_wait");
174                 snd.write();
175
176                 list<RenderTile> the_tiles;
177
178                 /* todo: run this threaded for connecting to multiple clients */
179                 for(;;) {
180                         RPCReceive rcv(socket);
181                         RenderTile tile;
182
183                         if(rcv.name == "acquire_tile") {
184                                 /* todo: watch out for recursive calls! */
185                                 if(the_task.acquire_tile(this, tile)) { /* write return as bool */
186                                         the_tiles.push_back(tile);
187
188                                         RPCSend snd(socket, "acquire_tile");
189                                         snd.add(tile);
190                                         snd.write();
191                                 }
192                                 else {
193                                         RPCSend snd(socket, "acquire_tile_none");
194                                         snd.write();
195                                 }
196                         }
197                         else if(rcv.name == "release_tile") {
198                                 rcv.read(tile);
199
200                                 for(list<RenderTile>::iterator it = the_tiles.begin(); it != the_tiles.end(); it++) {
201                                         if(tile.x == it->x && tile.y == it->y && tile.start_sample == it->start_sample) {
202                                                 tile.buffers = it->buffers;
203                                                 the_tiles.erase(it);
204                                                 break;
205                                         }
206                                 }
207
208                                 assert(tile.buffers != NULL);
209
210                                 the_task.release_tile(tile);
211
212                                 RPCSend snd(socket, "release_tile");
213                                 snd.write();
214                         }
215                         else if(rcv.name == "task_wait_done")
216                                 break;
217                 }
218         }
219
220         void task_cancel()
221         {
222                 RPCSend snd(socket, "task_cancel");
223                 snd.write();
224         }
225 };
226
227 Device *device_network_create(DeviceInfo& info, Stats &stats, const char *address)
228 {
229         return new NetworkDevice(stats, address);
230 }
231
232 void device_network_info(vector<DeviceInfo>& devices)
233 {
234         DeviceInfo info;
235
236         info.type = DEVICE_NETWORK;
237         info.description = "Network Device";
238         info.id = "NETWORK";
239         info.num = 0;
240         info.advanced_shading = true; /* todo: get this info from device */
241         info.pack_images = false;
242
243         devices.push_back(info);
244 }
245
246 class DeviceServer {
247 public:
248         DeviceServer(Device *device_, tcp::socket& socket_)
249         : device(device_), socket(socket_)
250         {
251         }
252
253         void listen()
254         {
255                 /* receive remote function calls */
256                 for(;;) {
257                         RPCReceive rcv(socket);
258
259                         if(rcv.name == "stop")
260                                 break;
261
262                         process(rcv);
263                 }
264         }
265
266 protected:
267         void process(RPCReceive& rcv)
268         {
269                 // fprintf(stderr, "receive process %s\n", rcv.name.c_str());
270
271                 if(rcv.name == "mem_alloc") {
272                         MemoryType type;
273                         network_device_memory mem;
274                         device_ptr remote_pointer;
275
276                         rcv.read(mem);
277                         rcv.read(type);
278
279                         /* todo: CPU needs mem.data_pointer */
280
281                         remote_pointer = mem.device_pointer;
282
283                         mem_data[remote_pointer] = vector<uint8_t>();
284                         mem_data[remote_pointer].resize(mem.memory_size());
285                         if(mem.memory_size())
286                                 mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]);
287                         else
288                                 mem.data_pointer = 0;
289
290                         device->mem_alloc(mem, type);
291
292                         ptr_map[remote_pointer] = mem.device_pointer;
293                         ptr_imap[mem.device_pointer] = remote_pointer;
294                 }
295                 else if(rcv.name == "mem_copy_to") {
296                         network_device_memory mem;
297
298                         rcv.read(mem);
299
300                         device_ptr remote_pointer = mem.device_pointer;
301                         mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]);
302
303                         rcv.read_buffer((uint8_t*)mem.data_pointer, mem.memory_size());
304
305                         mem.device_pointer = ptr_map[remote_pointer];
306
307                         device->mem_copy_to(mem);
308                 }
309                 else if(rcv.name == "mem_copy_from") {
310                         network_device_memory mem;
311                         int y, w, h, elem;
312
313                         rcv.read(mem);
314                         rcv.read(y);
315                         rcv.read(w);
316                         rcv.read(h);
317                         rcv.read(elem);
318
319                         device_ptr remote_pointer = mem.device_pointer;
320                         mem.device_pointer = ptr_map[remote_pointer];
321                         mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]);
322
323                         device->mem_copy_from(mem, y, w, h, elem);
324
325                         RPCSend snd(socket);
326                         snd.write();
327                         snd.write_buffer((uint8_t*)mem.data_pointer, mem.memory_size());
328                 }
329                 else if(rcv.name == "mem_zero") {
330                         network_device_memory mem;
331                         
332                         rcv.read(mem);
333                         device_ptr remote_pointer = mem.device_pointer;
334                         mem.device_pointer = ptr_map[mem.device_pointer];
335                         mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]);
336
337                         device->mem_zero(mem);
338                 }
339                 else if(rcv.name == "mem_free") {
340                         network_device_memory mem;
341                         device_ptr remote_pointer;
342
343                         rcv.read(mem);
344
345                         remote_pointer = mem.device_pointer;
346                         mem.device_pointer = ptr_map[mem.device_pointer];
347                         ptr_map.erase(remote_pointer);
348                         ptr_imap.erase(mem.device_pointer);
349                         mem_data.erase(remote_pointer);
350
351                         device->mem_free(mem);
352                 }
353                 else if(rcv.name == "const_copy_to") {
354                         string name_string;
355                         size_t size;
356
357                         rcv.read(name_string);
358                         rcv.read(size);
359
360                         vector<char> host_vector(size);
361                         rcv.read_buffer(&host_vector[0], size);
362
363                         device->const_copy_to(name_string.c_str(), &host_vector[0], size);
364                 }
365                 else if(rcv.name == "tex_alloc") {
366                         network_device_memory mem;
367                         string name;
368                         bool interpolation;
369                         bool periodic;
370                         device_ptr remote_pointer;
371
372                         rcv.read(name);
373                         rcv.read(mem);
374                         rcv.read(interpolation);
375                         rcv.read(periodic);
376
377                         remote_pointer = mem.device_pointer;
378
379                         mem_data[remote_pointer] = vector<uint8_t>();
380                         mem_data[remote_pointer].resize(mem.memory_size());
381                         if(mem.memory_size())
382                                 mem.data_pointer = (device_ptr)&(mem_data[remote_pointer][0]);
383                         else
384                                 mem.data_pointer = 0;
385
386                         rcv.read_buffer((uint8_t*)mem.data_pointer, mem.memory_size());
387
388                         device->tex_alloc(name.c_str(), mem, interpolation, periodic);
389
390                         ptr_map[remote_pointer] = mem.device_pointer;
391                         ptr_imap[mem.device_pointer] = remote_pointer;
392                 }
393                 else if(rcv.name == "tex_free") {
394                         network_device_memory mem;
395                         device_ptr remote_pointer;
396
397                         rcv.read(mem);
398
399                         remote_pointer = mem.device_pointer;
400                         mem.device_pointer = ptr_map[mem.device_pointer];
401                         ptr_map.erase(remote_pointer);
402                         ptr_map.erase(mem.device_pointer);
403                         mem_data.erase(remote_pointer);
404
405                         device->tex_free(mem);
406                 }
407                 else if(rcv.name == "task_add") {
408                         DeviceTask task;
409
410                         rcv.read(task);
411
412                         if(task.buffer) task.buffer = ptr_map[task.buffer];
413                         if(task.rgba) task.rgba = ptr_map[task.rgba];
414                         if(task.shader_input) task.shader_input = ptr_map[task.shader_input];
415                         if(task.shader_output) task.shader_output = ptr_map[task.shader_output];
416
417                         task.acquire_tile = function_bind(&DeviceServer::task_acquire_tile, this, _1, _2);
418                         task.release_tile = function_bind(&DeviceServer::task_release_tile, this, _1);
419                         task.update_progress_sample = function_bind(&DeviceServer::task_update_progress_sample, this);
420                         task.update_tile_sample = function_bind(&DeviceServer::task_update_tile_sample, this, _1);
421                         task.get_cancel = function_bind(&DeviceServer::task_get_cancel, this);
422
423                         device->task_add(task);
424                 }
425                 else if(rcv.name == "task_wait") {
426                         device->task_wait();
427
428                         RPCSend snd(socket, "task_wait_done");
429                         snd.write();
430                 }
431                 else if(rcv.name == "task_cancel") {
432                         device->task_cancel();
433                 }
434         }
435
436         bool task_acquire_tile(Device *device, RenderTile& tile)
437         {
438                 thread_scoped_lock acquire_lock(acquire_mutex);
439
440                 bool result = false;
441
442                 RPCSend snd(socket, "acquire_tile");
443                 snd.write();
444
445                 while(1) {
446                         RPCReceive rcv(socket);
447
448                         if(rcv.name == "acquire_tile") {
449                                 rcv.read(tile);
450
451                                 if(tile.buffer) tile.buffer = ptr_map[tile.buffer];
452                                 if(tile.rng_state) tile.rng_state = ptr_map[tile.rng_state];
453                                 if(tile.rgba) tile.rgba = ptr_map[tile.rgba];
454
455                                 result = true;
456                                 break;
457                         }
458                         else if(rcv.name == "acquire_tile_none")
459                                 break;
460                         else
461                                 process(rcv);
462                 }
463
464                 return result;
465         }
466
467         void task_update_progress_sample()
468         {
469                 ; /* skip */
470         }
471
472         void task_update_tile_sample(RenderTile&)
473         {
474                 ; /* skip */
475         }
476
477         void task_release_tile(RenderTile& tile)
478         {
479                 thread_scoped_lock acquire_lock(acquire_mutex);
480
481                 if(tile.buffer) tile.buffer = ptr_imap[tile.buffer];
482                 if(tile.rng_state) tile.rng_state = ptr_imap[tile.rng_state];
483                 if(tile.rgba) tile.rgba = ptr_imap[tile.rgba];
484
485                 RPCSend snd(socket, "release_tile");
486                 snd.add(tile);
487                 snd.write();
488
489                 while(1) {
490                         RPCReceive rcv(socket);
491
492                         if(rcv.name == "release_tile")
493                                 break;
494                         else
495                                 process(rcv);
496                 }
497         }
498
499         bool task_get_cancel()
500         {
501                 return false;
502         }
503
504         /* properties */
505         Device *device;
506         tcp::socket& socket;
507
508         /* mapping of remote to local pointer */
509         map<device_ptr, device_ptr> ptr_map;
510         map<device_ptr, device_ptr> ptr_imap;
511         map<device_ptr, vector<uint8_t> > mem_data;
512
513         thread_mutex acquire_mutex;
514
515         /* todo: free memory and device (osl) on network error */
516 };
517
518 void Device::server_run()
519 {
520         try {
521                 /* starts thread that responds to discovery requests */
522                 ServerDiscovery discovery;
523
524                 for(;;) {
525                         /* accept connection */
526                         boost::asio::io_service io_service;
527                         tcp::acceptor acceptor(io_service, tcp::endpoint(tcp::v4(), SERVER_PORT));
528
529                         tcp::socket socket(io_service);
530                         acceptor.accept(socket);
531
532                         string remote_address = socket.remote_endpoint().address().to_string();
533                         printf("Connected to remote client at: %s\n", remote_address.c_str());
534
535                         DeviceServer server(this, socket);
536                         server.listen();
537
538                         printf("Disconnected.\n");
539                 }
540         }
541         catch(exception& e) {
542                 fprintf(stderr, "Network server exception: %s\n", e.what());
543         }
544 }
545
546 #endif
547
548 CCL_NAMESPACE_END
549