Cycles: Calculate size of split state buffer kernel side
[blender.git] / intern / cycles / device / opencl / opencl_split.cpp
index 7e04c6fac2c958386c09c6045385a28b69c97baa..a44f5da3a32c5d18b153b8a87aa2d326ba1babad 100644 (file)
@@ -60,6 +60,7 @@ class OpenCLDeviceSplitKernel : public OpenCLDeviceBase
 public:
        DeviceSplitKernel *split_kernel;
        OpenCLProgram program_data_init;
+       OpenCLProgram program_state_buffer_size;
 
        OpenCLDeviceSplitKernel(DeviceInfo& info, Stats &stats, bool background_);
 
@@ -83,6 +84,13 @@ public:
                program_data_init.add_kernel(ustring("path_trace_data_init"));
                programs.push_back(&program_data_init);
 
+               program_state_buffer_size = OpenCLDeviceBase::OpenCLProgram(this,
+                                                 "split_state_buffer_size",
+                                                 "kernel_state_buffer_size.cl",
+                                                 get_build_options(this, requested_features));
+               program_state_buffer_size.add_kernel(ustring("path_trace_state_buffer_size"));
+               programs.push_back(&program_state_buffer_size);
+
                return split_kernel->load_kernels(requested_features);
        }
 
@@ -216,6 +224,41 @@ public:
                return kernel;
        }
 
+       virtual size_t state_buffer_size(device_memory& kg, device_memory& data, size_t num_threads)
+       {
+               /* Run path_trace_state_buffer_size to get the split state size for num_threads; 0 on error. */
+               device_vector<uint> size_buffer;
+               size_buffer.resize(1);
+               device->mem_alloc(NULL, size_buffer, MEM_READ_WRITE);
+               uint threads = num_threads; /* passed to the kernel as uint */
+               device->kernel_set_args(device->program_state_buffer_size(), 0, kg, data, threads, size_buffer);
+
+               size_t global_size = 64;
+               device->ciErr = clEnqueueNDRangeKernel(device->cqCommandQueue,
+                                              device->program_state_buffer_size(),
+                                              1,
+                                              NULL,
+                                              &global_size,
+                                              NULL,
+                                              0,
+                                              NULL,
+                                              NULL);
+               device->opencl_assert_err(device->ciErr, "clEnqueueNDRangeKernel");
+               /* If the kernel was not enqueued the buffer holds no result: release it and bail before reading back. */
+               if(device->ciErr != CL_SUCCESS) {
+                       device->mem_free(size_buffer);
+                       string message = string_printf("OpenCL error: %s in clEnqueueNDRangeKernel()",
+                                                      clewErrorString(device->ciErr));
+                       device->opencl_error(message);
+                       return 0;
+               }
+
+               device->mem_copy_from(size_buffer, 0, 1, 1, sizeof(uint));
+               size_t size = *size_buffer.get_data(); /* take the value before releasing the buffer */
+               device->mem_free(size_buffer);
+               return size;
+       }
+
        virtual bool enqueue_split_kernel_data_init(const KernelDimensions& dim,
                                                    RenderTile& rtile,
                                                    int num_global_elements,
@@ -298,7 +341,7 @@ public:
                return make_int2(64, 1);
        }
 
-       virtual int2 split_kernel_global_size(DeviceTask *task)
+       virtual int2 split_kernel_global_size(device_memory& kg, device_memory& data, DeviceTask */*task*/)
        {
                size_t max_buffer_size;
                clGetDeviceInfo(device->cdDevice, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &max_buffer_size, NULL);
@@ -306,7 +349,7 @@ public:
                        << string_human_readable_number(max_buffer_size) << " bytes. ("
                        << string_human_readable_size(max_buffer_size) << ").";
 
-               size_t num_elements = max_elements_for_max_buffer_size(max_buffer_size / 2, task->passes_size);
+               size_t num_elements = max_elements_for_max_buffer_size(kg, data, max_buffer_size / 2);
                int2 global_size = make_int2(round_down((int)sqrt(num_elements), 64), (int)sqrt(num_elements));
                VLOG(1) << "Global size: " << global_size << ".";
                return global_size;