Cycles: SSE optimization for line segments/ribbons hair
author: Sv. Lockal <lockalsash@gmail.com>
Sat, 22 Mar 2014 20:45:48 +0000 (00:45 +0400)
committer: Sv. Lockal <lockalsash@gmail.com>
Sat, 22 Mar 2014 20:45:59 +0000 (00:45 +0400)
Gives ~11% speedup for hair.blend, ~10% for koro_final.blend

Also extract a few common subexpressions in the hair calculation.

Reviewed By: brecht

Differential Revision: https://developer.blender.org/D318

intern/cycles/kernel/kernel_bvh.h
intern/cycles/util/util_simd.h

index 17791f4f35abae0c4adddb42a542ca3a367a5ae9..942c7abce65f9e31bcf0631c3d8d599d25775460 100644 (file)
@@ -596,6 +596,13 @@ ccl_device_inline bool bvh_cardinal_curve_intersect(KernelGlobals *kg, Intersect
 ccl_device_inline bool bvh_curve_intersect(KernelGlobals *kg, Intersection *isect,
        float3 P, float3 idir, uint visibility, int object, int curveAddr, int segment, uint *lcg_state, float difl, float extmax)
 {
+       /* define few macros to minimize code duplication for SSE */
+#ifndef __KERNEL_SSE2__
+#define len3_squared(x) len_squared(x)
+#define len3(x) len(x)
+#define dot3(x, y) dot(x, y)
+#endif
+
        /* curve Intersection check */
        int flags = kernel_data.curve.curveflags;
 
@@ -606,6 +613,7 @@ ccl_device_inline bool bvh_curve_intersect(KernelGlobals *kg, Intersection *isec
        int k0 = cnum + segment;
        int k1 = k0 + 1;
 
+#ifndef __KERNEL_SSE2__
        float4 P1 = kernel_tex_fetch(__curve_keys, k0);
        float4 P2 = kernel_tex_fetch(__curve_keys, k1);
 
@@ -617,36 +625,72 @@ ccl_device_inline bool bvh_curve_intersect(KernelGlobals *kg, Intersection *isec
        /* minimum width extension */
        float r1 = or1;
        float r2 = or2;
+       float3 dif = P - p1;
+       float3 dif_second = P - p2;
        if(difl != 0.0f) {
-               float pixelsize = min(len(p1 - P) * difl, extmax);
+               float pixelsize = min(len3(dif) * difl, extmax);
                r1 = or1 < pixelsize ? pixelsize : or1;
-               pixelsize = min(len(p2 - P) * difl, extmax);
+               pixelsize = min(len3(dif_second) * difl, extmax);
                r2 = or2 < pixelsize ? pixelsize : or2;
        }
        /* --- */
 
-       float mr = max(r1,r2);
-       float3 dif = P - p1;
-       float3 dir = 1.0f/idir;
-       float l = len(p2 - p1);
+       float3 dir = 1.0f / idir;
+       float3 p21_diff = p2 - p1;
+       float3 sphere_dif1 = (dif + dif_second) * 0.5f;
+       float sphere_b_tmp = dot3(dir, sphere_dif1);
+       float3 sphere_dif2 = sphere_dif1 - sphere_b_tmp * dir;
+#else
+       const __m128 p1 = _mm_load_ps(&kg->__curve_keys.data[k0].x);
+       const __m128 p2 = _mm_load_ps(&kg->__curve_keys.data[k1].x);
+       const __m128 or12 = shuffle<3, 3, 3, 3>(p1, p2);
+
+       __m128 r12 = or12;
+       const __m128 vP = load_m128(P);
+       const __m128 dif = _mm_sub_ps(vP, p1);
+       const __m128 dif_second = _mm_sub_ps(vP, p2);
+       if(difl != 0.0f) {
+               const __m128 len1_sq = len3_squared_splat(dif);
+               const __m128 len2_sq = len3_squared_splat(dif_second);
+               const __m128 len12 = _mm_sqrt_ps(shuffle<0, 0, 0, 0>(len1_sq, len2_sq));
+               const __m128 pixelsize12 = _mm_min_ps(_mm_mul_ps(len12, _mm_set1_ps(difl)), _mm_set1_ps(extmax));
+               r12 = _mm_max_ps(or12, pixelsize12);
+       }
+       float or1 = _mm_cvtss_f32(or12), or2 = _mm_cvtss_f32(broadcast<2>(or12));
+       float r1 = _mm_cvtss_f32(r12), r2 = _mm_cvtss_f32(broadcast<2>(r12));
+
+       const __m128 dir = _mm_div_ps(_mm_set1_ps(1.0f), load_m128(idir));
+       const __m128 p21_diff = _mm_sub_ps(p2, p1);
+       const __m128 sphere_dif1 = _mm_mul_ps(_mm_add_ps(dif, dif_second), _mm_set1_ps(0.5f));
+       const __m128 sphere_b_tmp = dot3_splat(dir, sphere_dif1);
+       const __m128 sphere_dif2 = fnma(sphere_b_tmp, dir, sphere_dif1);
+#endif
 
+       float mr = max(r1, r2);
+       float l = len3(p21_diff);
+       float invl = 1.0f / l;
        float sp_r = mr + 0.5f * l;
-       float3 sphere_dif = P - ((p1 + p2) * 0.5f);
-       float sphere_b = dot(dir,sphere_dif);
-       sphere_dif = sphere_dif - sphere_b * dir;
-       sphere_b = dot(dir,sphere_dif);
-       float sdisc = sphere_b * sphere_b - len_squared(sphere_dif) + sp_r * sp_r;
+
+       float sphere_b = dot3(dir, sphere_dif2);
+       float sdisc = sphere_b * sphere_b - len3_squared(sphere_dif2) + sp_r * sp_r;
+
        if(sdisc < 0.0f)
                return false;
 
        /* obtain parameters and test midpoint distance for suitable modes */
-       float3 tg = (p2 - p1) / l;
-       float gd = (r2 - r1) / l;
-       float dirz = dot(dir,tg);
-       float difz = dot(dif,tg);
+#ifndef __KERNEL_SSE2__
+       float3 tg = p21_diff * invl;
+#else
+       const __m128 tg = _mm_mul_ps(p21_diff, _mm_set1_ps(invl));
+#endif
+       float gd = (r2 - r1) * invl;
+
+       float dirz = dot3(dir, tg);
+       float difz = dot3(dif, tg);
 
        float a = 1.0f - (dirz*dirz*(1 + gd*gd));
-       float halfb = dot(dir,dif) - dirz*(difz + gd*(difz*gd + r1));
+
+       float halfb = dot3(dir, dif) - dirz*(difz + gd*(difz*gd + r1));
 
        float tcentre = -halfb/a;
        float zcentre = difz + (dirz * tcentre);
@@ -657,11 +701,15 @@ ccl_device_inline bool bvh_curve_intersect(KernelGlobals *kg, Intersection *isec
                return false;
 
        /* test minimum separation */
+#ifndef __KERNEL_SSE2__
        float3 cprod = cross(tg, dir);
-       float3 cprod2 = cross(tg, dif);
-       float cprodsq = len_squared(cprod);
-       float cprod2sq = len_squared(cprod2);
-       float distscaled = dot(cprod,dif);
+       float cprod2sq = len3_squared(cross(tg, dif));
+#else
+       const __m128 cprod = cross(tg, dir);
+       float cprod2sq = len3_squared(cross_zxy(tg, dif));
+#endif
+       float cprodsq = len3_squared(cprod);
+       float distscaled = dot3(cprod, dif);
 
        if(cprodsq == 0)
                distscaled = cprod2sq;
@@ -672,10 +720,15 @@ ccl_device_inline bool bvh_curve_intersect(KernelGlobals *kg, Intersection *isec
                return false;
 
        /* calculate true intersection */
-       float3 tdif = P - p1 + tcentre * dir;
-       float tdifz = dot(tdif,tg);
-       float tb = 2*(dot(dir,tdif) - dirz*(tdifz + gd*(tdifz*gd + r1)));
-       float tc = dot(tdif,tdif) - tdifz * tdifz * (1 + gd*gd) - r1*r1 - 2*r1*tdifz*gd;
+#ifndef __KERNEL_SSE2__
+       float3 tdif = dif + tcentre * dir;
+#else
+       const __m128 tdif = fma(_mm_set1_ps(tcentre), dir, dif);
+#endif
+       float tdifz = dot3(tdif, tg);
+       float tdifma = tdifz*gd + r1;
+       float tb = 2*(dot3(dir, tdif) - dirz*(tdifz + gd*tdifma));
+       float tc = dot3(tdif, tdif) - tdifz*tdifz - tdifma*tdifma;
        float td = tb*tb - 4*a*tc;
 
        if (td < 0.0f)
@@ -709,7 +762,7 @@ ccl_device_inline bool bvh_curve_intersect(KernelGlobals *kg, Intersection *isec
                }
 
                /* stochastic fade from minimum width */
-               float adjradius = or1 + z * (or2 - or1) / l;
+               float adjradius = or1 + z * (or2 - or1) * invl;
                adjradius = adjradius / (r1 + z * gd);
                if(lcg_state && adjradius != 1.0f) {
                        if(lcg_step_float(lcg_state) > adjradius)
@@ -721,9 +774,9 @@ ccl_device_inline bool bvh_curve_intersect(KernelGlobals *kg, Intersection *isec
 
                        if (flags & CURVE_KN_ENCLOSEFILTER) {
                                float enc_ratio = 1.01f;
-                               if((dot(P - p1, tg) > -r1 * enc_ratio) && (dot(P - p2, tg) < r2 * enc_ratio)) {
+                               if((difz > -r1 * enc_ratio) && (dot3(dif_second, tg) < r2 * enc_ratio)) {
                                        float a2 = 1.0f - (dirz*dirz*(1 + gd*gd*enc_ratio*enc_ratio));
-                                       float c2 = dot(dif,dif) - difz * difz * (1 + gd*gd*enc_ratio*enc_ratio) - r1*r1*enc_ratio*enc_ratio - 2*r1*difz*gd*enc_ratio;
+                                       float c2 = dot3(dif, dif) - difz * difz * (1 + gd*gd*enc_ratio*enc_ratio) - r1*r1*enc_ratio*enc_ratio - 2*r1*difz*gd*enc_ratio;
                                        if(a2*c2 < 0.0f)
                                                return false;
                                }
@@ -739,7 +792,7 @@ ccl_device_inline bool bvh_curve_intersect(KernelGlobals *kg, Intersection *isec
                                isect->prim = curveAddr;
                                isect->segment = segment;
                                isect->object = object;
-                               isect->u = z/l;
+                               isect->u = z*invl;
                                isect->v = td/(4*a*a);
                                /*isect->v = 1.0f - adjradius;*/
                                isect->t = t;
@@ -753,6 +806,12 @@ ccl_device_inline bool bvh_curve_intersect(KernelGlobals *kg, Intersection *isec
        }
 
        return false;
+
+#ifndef __KERNEL_SSE2__
+#undef len3_squared
+#undef len3
+#undef dot3
+#endif
 }
 #endif
 
index fff682bb436c67bfa2341290a013e6ca5d27df8e..486816cc5c0fb3b2cbf66ae033e655977a33044a 100644 (file)
@@ -154,6 +154,12 @@ ccl_device_inline const __m128 fms(const __m128& a, const __m128& b, const __m12
        return _mm_sub_ps(_mm_mul_ps(a, b), c);
 }
 
+/* calculate -a*b+c (replacement for fused negated-multiply-add on SSE CPUs; unfused, so rounds twice) */
+ccl_device_inline const __m128 fnma(const __m128& a, const __m128& b, const __m128& c)
+{
+       return _mm_sub_ps(c, _mm_mul_ps(a, b));
+}
+
 template<size_t N> ccl_device_inline const __m128 broadcast(const __m128& a)
 {
        return _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(a), _MM_SHUFFLE(N, N, N, N)));
@@ -204,6 +210,52 @@ ccl_device_inline const __m128 load_m128(const float3 &vec)
 }
 #endif /* __KERNEL_WITH_SSE_ALIGN__ */
 
+ccl_device_inline const __m128 dot3_splat(const __m128& a, const __m128& b)
+{
+#ifdef __KERNEL_SSE41__
+       return _mm_dp_ps(a, b, 0x7f); /* dot product of lanes 0-2 (w ignored), broadcast to all lanes like the #else branch */
+#else
+       __m128 t = _mm_mul_ps(a, b); /* NOTE(review): lanes are read through float* below; relies on __m128 tolerating this aliasing */
+       return _mm_set1_ps(((float*)&t)[0] + ((float*)&t)[1] + ((float*)&t)[2]);
+#endif
+}
+/* dot product of the x, y, z lanes of a and b (w lane ignored), returned as a scalar */
+ccl_device_inline float dot3(const __m128& a, const __m128& b)
+{
+#ifdef __KERNEL_SSE41__
+       return _mm_cvtss_f32(_mm_dp_ps(a, b, 0x7f));
+#else
+       __m128 t = _mm_mul_ps(a, b);
+       return ((float*)&t)[0] + ((float*)&t)[1] + ((float*)&t)[2];
+#endif
+}
+/* squared length of the x, y, z lanes of a, broadcast to all four lanes */
+ccl_device_inline const __m128 len3_squared_splat(const __m128& a)
+{
+       return dot3_splat(a, a);
+}
+/* squared length of the x, y, z lanes of a, returned as a scalar */
+ccl_device_inline float len3_squared(const __m128& a)
+{
+       return dot3(a, a);
+}
+/* length of the x, y, z lanes of a; the square root is only taken on lane 0 */
+ccl_device_inline float len3(const __m128& a)
+{
+       return _mm_cvtss_f32(_mm_sqrt_ss(dot3_splat(a, a)));
+}
+/* calculate shuffled cross product, useful when order of components does not matter:
+ * lanes 0, 1, 2 hold the z, x, y components of a x b (hence the name) */
+ccl_device_inline const __m128 cross_zxy(const __m128& a, const __m128& b)
+{
+       return fms(a, shuffle<1, 2, 0, 3>(b), _mm_mul_ps(b, shuffle<1, 2, 0, 3>(a)));
+}
+/* cross product of the x, y, z lanes of a and b: cross_zxy rotated back into x, y, z order */
+ccl_device_inline const __m128 cross(const __m128& a, const __m128& b)
+{
+       return shuffle<1, 2, 0, 3>(cross_zxy(a, b));
+}
+
 #endif /* __KERNEL_SSE2__ */
 
 CCL_NAMESPACE_END