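COMPMID-3968: Fix 30% performance regression on FSSD v1 25 Grayscale

Add a dedicated NEON path to NEDepthwiseConvolutionLayerNativeKernel for
quantized inputs with per-tensor quantized weights and a power-of-two depth
multiplier of at least 8. Each input value is broadcast and multiplied against
a whole vector of weights using widening multiply-accumulate (new wrapper::vmlal
overloads), instead of going through the generic quantized loop.
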
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Change-Id: Ib1ecd7aa10fec0b7e2b3d929e212c1af34c0f58d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4533
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEDepthwiseConvolutionLayerNativeKernel.cpp b/src/core/NEON/kernels/NEDepthwiseConvolutionLayerNativeKernel.cpp
index 8731590..23b9bc5 100644
--- a/src/core/NEON/kernels/NEDepthwiseConvolutionLayerNativeKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseConvolutionLayerNativeKernel.cpp
@@ -579,6 +579,180 @@
     input_it, weights_it, biases_it, output_it);
 }
 
+/** Quantized depthwise convolution loop for a power-of-two depth multiplier with per-tensor quantized weights.
+ *
+ * Each step of the execution window processes one input channel: its value is broadcast to all lanes and
+ * multiplied against vector_size consecutive weights at a time, so the depth_multiplier output channels derived
+ * from that input channel are accumulated with SIMD instead of one by one.
+ *
+ * @param[in]  input             Input tensor (channels along X, width along Y, height along Z, as implied by the strides/padding used below).
+ * @param[in]  weights           Weights tensor. Only its uniform (per-tensor) quantization offset is used.
+ * @param[in]  biases            int32 biases. Only read when @p has_biases is true.
+ * @param[out] output            Output tensor. Results are clamped to the range of T before being stored.
+ * @param[in]  conv_info         Padding and stride information.
+ * @param[in]  dilation          Dilation along width (x) and height (y).
+ * @param[in]  depth_multiplier  Output channels per input channel. Must be a non-zero multiple of vector_size.
+ * @param[in]  output_multiplier Requantization multipliers. Only element 0 is used.
+ * @param[in]  output_shift      Requantization shifts. Only element 0 is used.
+ * @param[in]  window            Execution window.
+ * @param[in]  has_biases        Whether @p biases is added to the accumulators.
+ */
+template <typename T, typename TW>
+void depthwise_loop_pow2_quantized_per_tensor(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info,
+                                              const Size2D &dilation, unsigned int depth_multiplier, std::vector<int> output_multiplier, std::vector<int> output_shift, const Window &window, bool has_biases)
+{
+    // Each vector_size-lane vector of T/TW is widened to 16 bits and accumulated into two int32 vectors of
+    // half_vec lanes each: the low half goes to acc0, the high half to acc1.
+    constexpr int half_vec = vector_size / 2;
+
+    using AccType          = int32_t;
+    using AccVectorType    = typename wrapper::traits::neon_vector<AccType, half_vec>::type;
+    using AccVectorTagType = typename wrapper::traits::neon_vector<AccType, half_vec>::tag_type;
+    using TagType          = typename wrapper::traits::neon_vector<T, vector_size>::tag_type;
+
+    const auto run_info = DepthwiseConvolutionRunInfo(*input->info(), *weights->info(), conv_info, window, depth_multiplier);
+
+    // Input/weights offsets are pre-widened to 16-bit lanes so they can be subtracted right after vmovl.
+    const auto input_qoffset_vec   = wrapper::vreinterpret(wrapper::vmovl(wrapper::vdup_n(static_cast<T>(input->info()->quantization_info().uniform().offset), TagType{})));
+    const auto weights_qoffset_vec = wrapper::vreinterpret(wrapper::vmovl(wrapper::vdup_n(static_cast<TW>(weights->info()->quantization_info().uniform().offset), TagType{})));
+    const auto output_qoffset_vec  = wrapper::vdup_n(output->info()->quantization_info().uniform().offset, arm_compute::wrapper::traits::vector_128_tag{});
+
+    // Saturation bounds: the representable range of the output element type T.
+    const auto lower = wrapper::vdup_n(static_cast<AccType>(std::numeric_limits<T>::lowest()), AccVectorTagType{});
+    const auto upper = wrapper::vdup_n(static_cast<AccType>(std::numeric_limits<T>::max()), AccVectorTagType{});
+    const auto zero  = wrapper::vdup_n(static_cast<AccType>(0), AccVectorTagType{});
+
+    // Per-tensor quantization: one multiplier/shift pair is shared by every output channel.
+    const auto out_mul   = output_multiplier.at(0);
+    const auto out_shift = output_shift.at(0);
+
+    // Step through input channels one at a time (X in [0, input_depth), step 1); the weights/biases/output
+    // iterators use windows whose X step is run_info.x_step. Input Y/Z are walked manually in the loop body.
+    Window execution_window = window;
+    execution_window.set(Window::DimX, Window::Dimension(0, run_info.input_depth, 1));
+
+    Window win_input = execution_window;
+    win_input.set(Window::DimY, dim_manual_loop);
+    win_input.set(Window::DimZ, dim_manual_loop);
+
+    Window win_weights = window;
+    win_weights.set_dimension_step(Window::DimX, run_info.x_step);
+    win_weights.set(Window::DimY, dim_manual_loop);
+    win_weights.set(Window::DimZ, dim_manual_loop);
+    win_weights.set(Window::DimW, dim_manual_loop);
+
+    Window win_output = window;
+    win_output.set_dimension_step(Window::DimX, run_info.x_step);
+
+    Iterator input_it(input, win_input);
+    Iterator weights_it(weights, win_weights);
+    Iterator output_it(output, win_output);
+    Iterator biases_it{};
+
+    if(has_biases)
+    {
+        biases_it = Iterator(biases, win_weights);
+    }
+
+    // The accumulator indexing below is unchecked: depth_multiplier must be a non-zero multiple of vector_size.
+    ARM_COMPUTE_ERROR_ON(depth_multiplier == 0 || (depth_multiplier % vector_size) != 0);
+
+    // One pair of int32 accumulators per group of vector_size output channels, reset at every step.
+    std::vector<AccVectorType> acc0(depth_multiplier / vector_size);
+    std::vector<AccVectorType> acc1(depth_multiplier / vector_size);
+
+    execute_window_loop(execution_window, [&](const Coordinates & id)
+    {
+        std::fill(begin(acc0), end(acc0), zero);
+        std::fill(begin(acc1), end(acc1), zero);
+
+        const int32_t input_y      = id.y() * run_info.conv_stride_x - run_info.conv_pad_left;
+        const int32_t input_z      = id.z() * run_info.conv_stride_y - run_info.conv_pad_top;
+        int64_t       input_offset = input_y * run_info.input_stride_y + input_z * run_info.input_stride_z;
+
+        auto weights_ptr = weights_it.ptr();
+        for(size_t h = 0; h < run_info.weights_height; ++h)
+        {
+            const int32_t current_h = input_z + h * dilation.y();
+            if(current_h >= 0 && current_h < static_cast<int32_t>(run_info.input_height))
+            {
+                // Keep the running offset 64-bit like input_offset; narrowing it to int could truncate.
+                int64_t offs = input_offset;
+                for(size_t w = 0; w < run_info.weights_width; ++w)
+                {
+                    const int32_t current_w = input_y + w * dilation.x();
+                    // Taps that fall into the padding region are skipped and contribute nothing.
+                    if(current_w >= 0 && current_w < static_cast<int32_t>(run_info.input_width))
+                    {
+                        // Broadcast the input value to every lane, widen it to 16 bits and remove its offset.
+                        const auto input_8x8     = wrapper::vdup_n(*(reinterpret_cast<T *>(input_it.ptr() + std::min(static_cast<size_t>(offs), run_info.input_max_offset))), TagType{});
+                        const auto input_s16x8   = wrapper::vreinterpret(wrapper::vmovl(input_8x8));
+                        const auto input_no_offs = wrapper::vsub(input_s16x8, input_qoffset_vec);
+
+                        for(size_t m = 0, i = 0; m < depth_multiplier; m += vector_size, ++i)
+                        {
+                            // Weights are of type TW, so their byte offset is scaled by sizeof(TW).
+                            const auto weights_8x8     = wrapper::vload(reinterpret_cast<TW *>(weights_ptr + m * sizeof(TW) + w * run_info.weights_stride_y));
+                            const auto weights_s16x8   = wrapper::vreinterpret(wrapper::vmovl(weights_8x8));
+                            const auto weights_no_offs = wrapper::vsub(weights_s16x8, weights_qoffset_vec);
+
+                            acc0[i] = wrapper::vmlal(acc0[i], wrapper::vgetlow(input_no_offs), wrapper::vgetlow(weights_no_offs));
+                            acc1[i] = wrapper::vmlal(acc1[i], wrapper::vgethigh(input_no_offs), wrapper::vgethigh(weights_no_offs));
+                        }
+                    }
+
+                    offs += dilation.x() * run_info.input_stride_y;
+                }
+            }
+
+            weights_ptr += run_info.weights_stride_z;
+            input_offset += dilation.y() * run_info.input_stride_z;
+        }
+
+        // Requantize each group of vector_size channels and store it as T.
+        for(size_t m = 0, i = 0; m < depth_multiplier; m += vector_size, ++i)
+        {
+            if(has_biases)
+            {
+                const auto bias_val0 = wrapper::vloadq(reinterpret_cast<int32_t *>(biases_it.ptr() + m * sizeof(int32_t)));
+                const auto bias_val1 = wrapper::vloadq(reinterpret_cast<int32_t *>(biases_it.ptr() + (m + half_vec) * sizeof(int32_t)));
+
+                acc0[i] = wrapper::vadd(acc0[i], bias_val0);
+                acc1[i] = wrapper::vadd(acc1[i], bias_val1);
+            }
+
+            // Negative shift: scale up before the fixed-point multiply; otherwise round-divide after it.
+            if(out_shift < 0)
+            {
+                acc0[i] = wrapper::vadd(saturating_doubling_high_mul(acc0[i] * (1 << (-out_shift)), out_mul), output_qoffset_vec);
+                acc1[i] = wrapper::vadd(saturating_doubling_high_mul(acc1[i] * (1 << (-out_shift)), out_mul), output_qoffset_vec);
+            }
+            else
+            {
+                acc0[i] = wrapper::vadd(rounding_divide_by_exp2(saturating_doubling_high_mul(acc0[i], out_mul), out_shift), output_qoffset_vec);
+                acc1[i] = wrapper::vadd(rounding_divide_by_exp2(saturating_doubling_high_mul(acc1[i], out_mul), out_shift), output_qoffset_vec);
+            }
+
+            acc0[i] = wrapper::vmin(wrapper::vmax(acc0[i], lower), upper);
+            acc1[i] = wrapper::vmin(wrapper::vmax(acc1[i], lower), upper);
+
+            const auto out_val = wrapper::vcombine(wrapper::vmovn(acc0[i]),
+                                                   wrapper::vmovn(acc1[i]));
+
+            // Values are already clamped to T's range, so the saturating narrowing below is lossless.
+            if(std::is_same<T, uint8_t>::value)
+            {
+                wrapper::vstore(reinterpret_cast<uint8_t *>(output_it.ptr() + m * sizeof(uint8_t)), wrapper::vqmovn(vreinterpretq_u16_s16(out_val)));
+            }
+            else
+            {
+                wrapper::vstore(reinterpret_cast<int8_t *>(output_it.ptr() + m * sizeof(int8_t)), wrapper::vqmovn(out_val));
+            }
+        }
+    },
+    input_it, weights_it, biases_it, output_it);
+}
+
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier,
                           const Size2D &dilation)
 {
@@ -761,7 +935,20 @@
     }
     else
     {
-        depthwise_loop_generic_quantized<T, TW>(_input, _weights, _biases, _output, _conv_info, _dilation, _depth_multiplier, _output_multiplier, _output_shift, window, has_biases);
+        // The pow2 path consumes depth_multiplier in chunks of vector_size and applies a single multiplier/shift
+        // pair, so it needs per-tensor weights and a power-of-two depth multiplier of at least 8.
+        // NOTE(review): 8 is assumed to equal vector_size; keep the two in sync.
+        const bool is_pow2                 = ((_depth_multiplier & (_depth_multiplier - 1)) == 0);
+        const bool is_quantized_per_tensor = !(is_data_type_quantized_per_channel(_weights->info()->data_type()));
+
+        if(is_pow2 && is_quantized_per_tensor && _depth_multiplier >= 8)
+        {
+            depthwise_loop_pow2_quantized_per_tensor<T, TW>(_input, _weights, _biases, _output, _conv_info, _dilation, _depth_multiplier, _output_multiplier, _output_shift, window, has_biases);
+        }
+        else
+        {
+            depthwise_loop_generic_quantized<T, TW>(_input, _weights, _biases, _output, _conv_info, _dilation, _depth_multiplier, _output_multiplier, _output_shift, window, has_biases);
+        }
     }
 }
 } // namespace arm_compute
diff --git a/src/core/NEON/wrapper/intrinsics/mla.h b/src/core/NEON/wrapper/intrinsics/mla.h
index 2b38b34..9fb5a08 100644
--- a/src/core/NEON/wrapper/intrinsics/mla.h
+++ b/src/core/NEON/wrapper/intrinsics/mla.h
@@ -66,6 +66,27 @@
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
 #undef VMLA_IMPL
+
+/** Widening multiply-accumulate (vmlal) overloads.
+ *
+ * Each overload forwards to the NEON vmlal_<suffix> intrinsic and returns @p acc + @p lhs * @p rhs, where the
+ * lanes of @p lhs and @p rhs are widened to the (twice as wide) element type of @p acc before multiplying.
+ */
+#define VMLAL_IMPL(narrow_vec_t, wide_vec_t, suffix)                                                    \
+    inline wide_vec_t vmlal(const wide_vec_t &acc, const narrow_vec_t &lhs, const narrow_vec_t &rhs) \
+    {                                                                                                \
+        return vmlal_##suffix(acc, lhs, rhs);                                                        \
+    }
+
+VMLAL_IMPL(uint8x8_t, uint16x8_t, u8)
+VMLAL_IMPL(int8x8_t, int16x8_t, s8)
+VMLAL_IMPL(uint16x4_t, uint32x4_t, u16)
+VMLAL_IMPL(int16x4_t, int32x4_t, s16)
+VMLAL_IMPL(uint32x2_t, uint64x2_t, u32)
+VMLAL_IMPL(int32x2_t, int64x2_t, s32)
+
+#undef VMLAL_IMPL
+
 } // namespace wrapper
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_WRAPPER_MLA_H */
diff --git a/src/core/NEON/wrapper/intrinsics/reinterpret.h b/src/core/NEON/wrapper/intrinsics/reinterpret.h
index 0c26cd9..cf00a4a 100644
--- a/src/core/NEON/wrapper/intrinsics/reinterpret.h
+++ b/src/core/NEON/wrapper/intrinsics/reinterpret.h
@@ -42,7 +42,7 @@
     }
 
 VREINTERPRET_IMPL(int16x4_t, uint16x4_t, vreinterpret, s16, u16)
-
+VREINTERPRET_IMPL(int16x8_t, uint16x8_t, vreinterpretq, s16, u16)
 VREINTERPRET_IMPL(int32x4_t, uint32x4_t, vreinterpretq, s32, u32)
 } // namespace wrapper
 } // namespace arm_compute