COMPMID-515: L2 Pooling for FP32/FP16 in CL.

Change-Id: I43641fa672f5905ca62edd1f63fc93e0cf7ea382
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/85963
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index fdcbd5a..b97564e 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -29,6 +29,7 @@
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/NEON/NEFixedPoint.h"
+#include "arm_compute/core/NEON/NEMath.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
@@ -36,6 +37,7 @@
 
 #include <algorithm>
 #include <arm_neon.h>
+#include <cmath>
 #include <limits>
 #include <set>
 #include <string>
@@ -111,6 +113,7 @@
 
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON(pool_type == PoolingType::L2 && is_data_type_fixed_point(input->info()->data_type()));
     ARM_COMPUTE_ERROR_ON(supported_pool_sizes.find(pool_size) == supported_pool_sizes.end());
     ARM_COMPUTE_ERROR_ON(7 == pool_size && input->info()->data_type() != DataType::F32);
     ARM_COMPUTE_ERROR_ON(pool_pad_x >= pool_size || pool_pad_y >= pool_size);
@@ -235,41 +238,146 @@
         case 2:
             if(input->info()->data_type() == DataType::QS8)
             {
-                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling2_q8<PoolingType::AVG> : &NEPoolingLayerKernel::pooling2_q8<PoolingType::MAX>;
+                switch(pool_type)
+                {
+                    case PoolingType::AVG:
+                        _func = &NEPoolingLayerKernel::pooling2_q8<PoolingType::AVG>;
+                        break;
+                    case PoolingType::MAX:
+                        _func = &NEPoolingLayerKernel::pooling2_q8<PoolingType::MAX>;
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported pooling type!");
+                }
             }
             else if(input->info()->data_type() == DataType::QS16)
             {
-                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling2_q16<PoolingType::AVG> : &NEPoolingLayerKernel::pooling2_q16<PoolingType::MAX>;
+                switch(pool_type)
+                {
+                    case PoolingType::AVG:
+                        _func = &NEPoolingLayerKernel::pooling2_q16<PoolingType::AVG>;
+                        break;
+                    case PoolingType::MAX:
+                        _func = &NEPoolingLayerKernel::pooling2_q16<PoolingType::MAX>;
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported pooling type!");
+                }
             }
             else if(input->info()->data_type() == DataType::F16)
             {
-                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling2_f16<PoolingType::AVG> : &NEPoolingLayerKernel::pooling2_f16<PoolingType::MAX>;
+                switch(pool_type)
+                {
+                    case PoolingType::AVG:
+                        _func = &NEPoolingLayerKernel::pooling2_f16<PoolingType::AVG>;
+                        break;
+                    case PoolingType::L2:
+                        _func = &NEPoolingLayerKernel::pooling2_f16<PoolingType::L2>;
+                        break;
+                    case PoolingType::MAX:
+                        _func = &NEPoolingLayerKernel::pooling2_f16<PoolingType::MAX>;
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported pooling type!");
+                }
             }
             else if(input->info()->data_type() == DataType::F32)
             {
-                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling2_f32<PoolingType::AVG> : &NEPoolingLayerKernel::pooling2_f32<PoolingType::MAX>;
+                switch(pool_type)
+                {
+                    case PoolingType::AVG:
+                        _func = &NEPoolingLayerKernel::pooling2_f32<PoolingType::AVG>;
+                        break;
+                    case PoolingType::L2:
+                        _func = &NEPoolingLayerKernel::pooling2_f32<PoolingType::L2>;
+                        break;
+                    case PoolingType::MAX:
+                        _func = &NEPoolingLayerKernel::pooling2_f32<PoolingType::MAX>;
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported pooling type!");
+                }
             }
             break;
         case 3:
             if(input->info()->data_type() == DataType::QS8)
             {
-                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling3_q8<PoolingType::AVG> : &NEPoolingLayerKernel::pooling3_q8<PoolingType::MAX>;
+                switch(pool_type)
+                {
+                    case PoolingType::AVG:
+                        _func = &NEPoolingLayerKernel::pooling3_q8<PoolingType::AVG>;
+                        break;
+                    case PoolingType::MAX:
+                        _func = &NEPoolingLayerKernel::pooling3_q8<PoolingType::MAX>;
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported pooling type!");
+                }
             }
             else if(input->info()->data_type() == DataType::QS16)
             {
-                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling3_q16<PoolingType::AVG> : &NEPoolingLayerKernel::pooling3_q16<PoolingType::MAX>;
+                switch(pool_type)
+                {
+                    case PoolingType::AVG:
+                        _func = &NEPoolingLayerKernel::pooling3_q16<PoolingType::AVG>;
+                        break;
+                    case PoolingType::MAX:
+                        _func = &NEPoolingLayerKernel::pooling3_q16<PoolingType::MAX>;
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported pooling type!");
+                }
             }
             else if(input->info()->data_type() == DataType::F16)
             {
-                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling3_f16<PoolingType::AVG> : &NEPoolingLayerKernel::pooling3_f16<PoolingType::MAX>;
+                switch(pool_type)
+                {
+                    case PoolingType::AVG:
+                        _func = &NEPoolingLayerKernel::pooling3_f16<PoolingType::AVG>;
+                        break;
+                    case PoolingType::L2:
+                        _func = &NEPoolingLayerKernel::pooling3_f16<PoolingType::L2>;
+                        break;
+                    case PoolingType::MAX:
+                        _func = &NEPoolingLayerKernel::pooling3_f16<PoolingType::MAX>;
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported pooling type!");
+                }
             }
             else if(input->info()->data_type() == DataType::F32)
             {
-                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling3_f32<PoolingType::AVG> : &NEPoolingLayerKernel::pooling3_f32<PoolingType::MAX>;
+                switch(pool_type)
+                {
+                    case PoolingType::AVG:
+                        _func = &NEPoolingLayerKernel::pooling3_f32<PoolingType::AVG>;
+                        break;
+                    case PoolingType::L2:
+                        _func = &NEPoolingLayerKernel::pooling3_f32<PoolingType::L2>;
+                        break;
+                    case PoolingType::MAX:
+                        _func = &NEPoolingLayerKernel::pooling3_f32<PoolingType::MAX>;
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported pooling type!");
+                }
             }
             break;
         case 7:
-            _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling7_f32<PoolingType::AVG> : &NEPoolingLayerKernel::pooling7_f32<PoolingType::MAX>;
+            switch(pool_type)
+            {
+                case PoolingType::AVG:
+                    _func = &NEPoolingLayerKernel::pooling7_f32<PoolingType::AVG>;
+                    break;
+                case PoolingType::L2:
+                    _func = &NEPoolingLayerKernel::pooling7_f32<PoolingType::L2>;
+                    break;
+                case PoolingType::MAX:
+                    _func = &NEPoolingLayerKernel::pooling7_f32<PoolingType::MAX>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Unsupported pooling type!");
+            }
             break;
         default:
             ARM_COMPUTE_ERROR("Unsupported pooling size");
@@ -436,11 +544,20 @@
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        const float16x4_t top_data    = vld1_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
-        const float16x4_t middle_data = vld1_f16(reinterpret_cast<const float16_t *>(input_middle_ptr + input.offset()));
-        const float16x4_t bottom_data = vld1_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
-        float16x4_t       res         = {};
-        if(pooling_type == PoolingType::AVG)
+        float16x4_t top_data    = vld1_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
+        float16x4_t middle_data = vld1_f16(reinterpret_cast<const float16_t *>(input_middle_ptr + input.offset()));
+        float16x4_t bottom_data = vld1_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
+        float16x4_t res         = {};
+
+        // Square the input values in case of L2 pooling
+        if(pooling_type == PoolingType::L2)
+        {
+            top_data    = vmul_f16(top_data, top_data);
+            middle_data = vmul_f16(middle_data, middle_data);
+            bottom_data = vmul_f16(bottom_data, bottom_data);
+        }
+
+        if(pooling_type != PoolingType::MAX)
         {
             // Calculate scale
             const float       scale   = calculate_avg_scale(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_x, pool_pad_y, pool_stride_x, pool_stride_y);
@@ -456,6 +573,13 @@
             res                        = vpmax_f16(vset_lane_f16(-std::numeric_limits<float>::max(), max_data, 3), max_data);
             res                        = vpmax_f16(res, res);
         }
+
+        // Calculate square-root in case of l2 pooling
+        if(pooling_type == PoolingType::L2)
+        {
+            res = vinv_f16(vinvsqrt_f16(res));
+        }
+
         *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(res, 0);
     },
     input, output);
@@ -484,11 +608,20 @@
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        const auto  top_data    = vld2q_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
-        const auto  bottom_data = vld2q_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
+        auto        top_data    = vld2q_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
+        auto        bottom_data = vld2q_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
         float16x8_t res         = {};
 
-        if(pooling_type == PoolingType::AVG)
+        // Square the input values in case of L2 pooling
+        if(pooling_type == PoolingType::L2)
+        {
+            top_data.val[0]    = vmulq_f16(top_data.val[0], top_data.val[0]);
+            top_data.val[1]    = vmulq_f16(top_data.val[1], top_data.val[1]);
+            bottom_data.val[0] = vmulq_f16(bottom_data.val[0], bottom_data.val[0]);
+            bottom_data.val[1] = vmulq_f16(bottom_data.val[1], bottom_data.val[1]);
+        }
+
+        if(pooling_type != PoolingType::MAX)
         {
             const float       scale   = calculate_avg_scale(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_x, pool_pad_y, pool_stride_x, pool_stride_y);
             const float16x8_t scale_v = vdupq_n_f16(scale);
@@ -498,6 +631,14 @@
         {
             res = vmaxq_f16(bottom_data.val[1], vmaxq_f16(bottom_data.val[0], vmaxq_f16(top_data.val[0], top_data.val[1])));
         }
+
+        // Calculate square-root in case of l2 pooling
+        if(pooling_type == PoolingType::L2)
+        {
+            res = vinvq_f16(vinvsqrtq_f16(res));
+        }
+
+        // Store result
         vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), res);
     },
     input, output);
@@ -529,10 +670,19 @@
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        const float32x2_t top_data    = vld1_f32(reinterpret_cast<const float *>(input_top_ptr + input.offset()));
-        const float32x2_t bottom_data = vld1_f32(reinterpret_cast<const float *>(input_bottom_ptr + input.offset()));
-        float32x2_t       res         = {};
-        if(pooling_type == PoolingType::AVG)
+        float32x2_t top_data    = vld1_f32(reinterpret_cast<const float *>(input_top_ptr + input.offset()));
+        float32x2_t bottom_data = vld1_f32(reinterpret_cast<const float *>(input_bottom_ptr + input.offset()));
+        float32x2_t res         = {};
+        float       final_res   = 0;
+
+        // Square the input values in case of L2 pooling
+        if(pooling_type == PoolingType::L2)
+        {
+            top_data    = vmul_f32(top_data, top_data);
+            bottom_data = vmul_f32(bottom_data, bottom_data);
+        }
+
+        if(pooling_type != PoolingType::MAX)
         {
             // Calculate scale
             float             scale   = calculate_avg_scale(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_x, pool_pad_y, pool_stride_x, pool_stride_y);
@@ -547,7 +697,16 @@
             const float32x2_t max_data = vmax_f32(top_data, bottom_data);
             res                        = vpmax_f32(max_data, max_data);
         }
-        *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(res, 0);
+        final_res = vget_lane_f32(res, 0);
+
+        // Calculate square-root in case of l2 pooling
+        if(pooling_type == PoolingType::L2)
+        {
+            final_res = sqrt(final_res);
+        }
+
+        // Store result
+        *(reinterpret_cast<float *>(output.ptr())) = final_res;
     },
     input, output);
 }
@@ -719,11 +878,21 @@
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        const float32x4_t top_data    = vld1q_f32(reinterpret_cast<const float *>(input_top_ptr + input.offset()));
-        const float32x4_t middle_data = vld1q_f32(reinterpret_cast<const float *>(input_middle_ptr + input.offset()));
-        const float32x4_t bottom_data = vld1q_f32(reinterpret_cast<const float *>(input_bottom_ptr + input.offset()));
-        float32x2_t       res         = {};
-        if(pooling_type == PoolingType::AVG)
+        float32x4_t top_data    = vld1q_f32(reinterpret_cast<const float *>(input_top_ptr + input.offset()));
+        float32x4_t middle_data = vld1q_f32(reinterpret_cast<const float *>(input_middle_ptr + input.offset()));
+        float32x4_t bottom_data = vld1q_f32(reinterpret_cast<const float *>(input_bottom_ptr + input.offset()));
+        float32x2_t res         = {};
+        float       final_res   = 0;
+
+        // Square the input values in case of L2 pooling
+        if(pooling_type == PoolingType::L2)
+        {
+            top_data    = vmulq_f32(top_data, top_data);
+            middle_data = vmulq_f32(middle_data, middle_data);
+            bottom_data = vmulq_f32(bottom_data, bottom_data);
+        }
+
+        if(pooling_type != PoolingType::MAX)
         {
             // Calculate scale
             float             scale   = calculate_avg_scale(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_x, pool_pad_y, pool_stride_x, pool_stride_y);
@@ -740,7 +909,16 @@
             res                        = vpmax_f32(vget_high_f32(vsetq_lane_f32(-std::numeric_limits<float>::max(), max_data, 3)), vget_low_f32(max_data));
             res                        = vpmax_f32(res, res);
         }
-        *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(res, 0);
+        final_res = vget_lane_f32(res, 0);
+
+        // Calculate square-root in case of l2 pooling
+        if(pooling_type == PoolingType::L2)
+        {
+            final_res = sqrt(final_res);
+        }
+
+        // Store result
+        *(reinterpret_cast<float *>(output.ptr())) = final_res;
     },
     input, output);
 }
@@ -769,19 +947,32 @@
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        float32x2_t res = {};
-        if(pooling_type == PoolingType::AVG)
+        float32x2_t res       = {};
+        float       final_res = 0.f;
+        if(pooling_type != PoolingType::MAX)
         {
             // Calculate scale
             float             scale   = calculate_avg_scale(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_x, pool_pad_y, pool_stride_x, pool_stride_y);
             const float32x2_t scale_v = vdup_n_f32(scale);
 
             // Perform pooling
-            float32x4x2_t data     = vld2q_f32(reinterpret_cast<const float *>(input_ptrs[0] + input.offset()));
-            float32x4_t   sum_data = vaddq_f32(data.val[0], vsetq_lane_f32(0.f, data.val[1], 3));
+            float32x4x2_t data = vld2q_f32(reinterpret_cast<const float *>(input_ptrs[0] + input.offset()));
+            // Square the input values in case of L2 pooling
+            if(pooling_type == PoolingType::L2)
+            {
+                data.val[0] = vmulq_f32(data.val[0], data.val[0]);
+                data.val[1] = vmulq_f32(data.val[1], data.val[1]);
+            }
+            float32x4_t sum_data = vaddq_f32(data.val[0], vsetq_lane_f32(0.f, data.val[1], 3));
             for(int i = 1; i < pool_size; ++i)
             {
-                data     = vld2q_f32(reinterpret_cast<const float *>(input_ptrs[i] + input.offset()));
+                data = vld2q_f32(reinterpret_cast<const float *>(input_ptrs[i] + input.offset()));
+                // Square the input values in case of L2 pooling
+                if(pooling_type == PoolingType::L2)
+                {
+                    data.val[0] = vmulq_f32(data.val[0], data.val[0]);
+                    data.val[1] = vmulq_f32(data.val[1], data.val[1]);
+                }
                 sum_data = vaddq_f32(sum_data, data.val[0]);
                 sum_data = vaddq_f32(sum_data, vsetq_lane_f32(0.f, data.val[1], 3));
             }
@@ -800,7 +991,16 @@
             res = vpmax_f32(res, vpmax_f32(vget_high_f32(max_data.val[0]), vget_low_f32(max_data.val[0])));
             res = vpmax_f32(res, res);
         }
-        *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(res, 0);
+        final_res = vget_lane_f32(res, 0);
+
+        // Calculate square-root in case of l2 pooling
+        if(pooling_type == PoolingType::L2)
+        {
+            final_res = sqrt(final_res);
+        }
+
+        // Store result
+        *(reinterpret_cast<float *>(output.ptr())) = final_res;
     },
     input, output);
 }