COMPMID-2801: Add support for QASYMM8_SIGNED in NEDirectConvolutionLayerOutputStageKernel

Change-Id: Ib047dd1024b8ecac60e2d368cb161ca418c933ff
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2503
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h
index f358153..d009ccc 100644
--- a/arm_compute/core/KernelDescriptors.h
+++ b/arm_compute/core/KernelDescriptors.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -83,5 +83,14 @@
     bool     is_log{ false };                      /**< Flag used to perform Log Softmax operation */
     DataType input_data_type{ DataType::UNKNOWN }; /**< Input tensor data type */
 };
+
+/** Descriptor used by the direct convolution layer output stage kernels */
+struct DirectConvolutionLayerOutputStageKernelInfo
+{
+    int32_t  result_fixedpoint_multiplier{ 0 };     /**< Result output stage multiplier used for quantizing */
+    int32_t  result_shift{ 0 };                     /**< Result output stage shift used for quantizing */
+    int32_t  result_offset_after_shift{ 0 };        /**< Result offset used for quantizing */
+    DataType output_data_type{ DataType::UNKNOWN }; /**< Output tensor data type to use if the output is not initialized */
+};
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_CORE_KERNEL_DESCRIPTORS_H */
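
For reference, a minimal caller-side sketch of filling the new descriptor for a
signed quantized output stage (field values below are illustrative, not taken
from this patch):

    DirectConvolutionLayerOutputStageKernelInfo info;
    info.result_fixedpoint_multiplier = 1073741824;               // ~0.5 in Q0.31 fixed point (illustrative)
    info.result_shift                 = 1;                        // rounded power-of-two shift applied after the multiply
    info.result_offset_after_shift    = -128;                     // output zero-point (illustrative)
    info.output_data_type             = DataType::QASYMM8_SIGNED; // drives auto-initialization of the output
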
diff --git a/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h b/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h
index 3f41edc..b7632d7 100644
--- a/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h
+++ b/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,6 +24,7 @@
 #ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYEROUTPUTSTAGEKERNEL_H
 #define ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYEROUTPUTSTAGEKERNEL_H
 
+#include "arm_compute/core/KernelDescriptors.h"
 #include "arm_compute/core/NEON/INEKernel.h"
 
 namespace arm_compute
@@ -32,6 +33,8 @@
 /** NEON kernel to accumulate the biases, if provided, or downscale in case of quantized input.
  *
  * @note We assume bias to be shared
+ * @note For quantized computations (i.e. @p input of S32 type) the output data type for auto-initialization must be passed as part
+ *       of the @ref DirectConvolutionLayerOutputStageKernelInfo.
  */
 class NEDirectConvolutionLayerOutputStageKernel : public INEKernel
 {
@@ -54,32 +57,30 @@
     ~NEDirectConvolutionLayerOutputStageKernel() = default;
     /** Set the accumulate buffer and the biases of the kernel.
      *
-     * @param[in, out] input                        Input to add the bias to. If @p output is not specified then accumulation is done in-place.
-     *                                              Data type supported: F16/F32
-     * @param[in]      bias                         (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input
-     * @param[out]     output                       (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr)
-     *                                              Data type supported: F16/F32
-     * @param[in]      result_fixedpoint_multiplier (Optional) Fixed point value to be multiplied to each element of the input matrix once the result_offset has been added
-     * @param[in]      result_shift                 (Optional) Integer value used to round the result of the fixed point multiplication to nearest division by a power-of-two
-     * @param[in]      result_offset_after_shift    (Optional) Offset to be applied to result before converting it back to QASYMM8
+     * @param[in, out] input  Input to add the bias to. If @p output is not specified then accumulation is done in-place.
+     *                        Data type supported: F16/F32/S32
+     * @param[in]      bias   (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input
+     * @param[out]     output (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr)
+     *                        Note that in-place computation is only supported for F16/F32. For S32 this must not be nullptr.
+     *                        Data type supported: F16/F32 or QASYMM8/QASYMM8_SIGNED if @p input is S32
+     * @param[in]      info   (Optional) DirectConvolutionLayerOutputStageKernelInfo descriptor carrying the output stage parameters (multiplier, shift, offset) and the output data type
      */
     void configure(ITensor *input, const ITensor *bias = nullptr, ITensor *output = nullptr,
-                   int result_fixedpoint_multiplier = 0, int result_shift = 0, int result_offset_after_shift = 0);
+                   const DirectConvolutionLayerOutputStageKernelInfo &info = DirectConvolutionLayerOutputStageKernelInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref NEDirectConvolutionLayerOutputStageKernel
      *
-     * @param[in] input                        Input to add the bias to. If @p output is not specified then accumulation is done in-place.
-     *                                         Data type supported: F16/F32
-     * @param[in] bias                         (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input
-     * @param[in] output                       (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr)
-     *                                         Data type supported: F16/F32
-     * @param[in] result_fixedpoint_multiplier (Optional) Fixed point value to be multiplied to each element of the input matrix once the result_offset has been added
-     * @param[in] result_shift                 (Optional) Integer value used to round the result of the fixed point multiplication to nearest division by a power-of-two
-     * @param[in] result_offset_after_shift    (Optional) Offset to be applied to result before converting it back to QASYMM8
+     * @param[in] input  Input to add the bias to. If @p output is not specified then accumulation is done in-place.
+     *                   Data type supported: F16/F32/S32
+     * @param[in] bias   (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input
+     * @param[in] output (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr)
+     *                   Note that in-place computation is only supported for F16/F32. For S32 this must not be nullptr.
+     *                   Data type supported: F16/F32 or QASYMM8/QASYMM8_SIGNED if @p input is S32
+     * @param[in] info   (Optional) DirectConvolutionLayerOutputStageKernelInfo descriptor carrying the output stage parameters (multiplier, shift, offset) and the output data type
      *
      * @return a status
      */
     static Status validate(const ITensorInfo *input, const ITensorInfo *bias = nullptr, const ITensorInfo *output = nullptr,
-                           int result_fixedpoint_multiplier = 0, int result_shift = 0, int result_offset_after_shift = 0);
+                           const DirectConvolutionLayerOutputStageKernelInfo &info = DirectConvolutionLayerOutputStageKernelInfo());
 
     // Inherited methods overridden:
     void run(const Window &window, const ThreadInfo &info) override;
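
A hedged usage sketch of the updated interface (tensor setup elided; accum_s32,
biases and output_q8 are hypothetical, already-initialized tensors):

    NEDirectConvolutionLayerOutputStageKernel output_stage;
    DirectConvolutionLayerOutputStageKernelInfo info;
    info.output_data_type = DataType::QASYMM8_SIGNED;

    // Static validation first, mirroring the library's usual call pattern.
    ARM_COMPUTE_ERROR_THROW_ON(NEDirectConvolutionLayerOutputStageKernel::validate(
        accum_s32.info(), biases.info(), output_q8.info(), info));
    output_stage.configure(&accum_s32, &biases, &output_q8, info);
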
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index 8834d97..2f106a3 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,9 +30,11 @@
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/NEON/NEAsymm.h"
 #include "arm_compute/core/NEON/NEFixedPoint.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/Traits.h"
 
 #include <arm_neon.h>
 #include <cstddef>
@@ -43,62 +45,68 @@
 namespace
 {
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
-                          int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+                          const DirectConvolutionLayerOutputStageKernelInfo &info)
 {
-    ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
-    ARM_COMPUTE_UNUSED(result_shift);
-    ARM_COMPUTE_UNUSED(result_offset_after_shift);
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
     ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8,
-                                                         DataType::F16,
-                                                         DataType::S32, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::S32, DataType::F32);
 
     if(bias != nullptr)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::F16, DataType::S32, DataType::F32);
-
-        if(is_data_type_quantized_asymmetric(input->data_type()))
-        {
-            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32);
-        }
-        else
-        {
-            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
-        }
-
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
         ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)));
         ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
     }
 
+    if(input->data_type() == DataType::S32)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output == nullptr, "In-place computation not allowed for quantized output");
+    }
+
     // Checks performed when output is configured
     if((output != nullptr) && (output->total_size() != 0))
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::F32);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
-
-        if(is_data_type_quantized_asymmetric(output->data_type()))
-        {
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && output->data_type() != DataType::QASYMM8, "Wrong data type for bias");
-        }
-        else
+        if(is_data_type_float(input->data_type()))
         {
             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
+        }
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+    }
+    else if(input->data_type() == DataType::S32)
+    {
+        // In case of quantized computation and unconfigured output, the output data type must be provided through DirectConvolutionLayerOutputStageKernelInfo
+        ARM_COMPUTE_RETURN_ERROR_ON((info.output_data_type != DataType::QASYMM8) && (info.output_data_type != DataType::QASYMM8_SIGNED));
     }
 
     return Status{};
 }
 
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output,
+                                                        const DirectConvolutionLayerOutputStageKernelInfo &info)
 {
     ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
 
+    const DataType data_type = input->data_type();
+
+    // Auto-initialize the output if required
+    if(output != nullptr)
+    {
+        // Work out expected output data type
+        const DataType output_dt = (data_type == DataType::S32) ? info.output_data_type : data_type;
+        // Output tensor auto initialization if not yet initialized
+        auto_init_if_empty(*output, input->clone()->set_data_type(output_dt));
+    }
+
     bool         window_changed                    = false;
-    unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type());
+    unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(data_type);
 
     // Update processed elements when input is S32 (comes from quantization input)
-    if(input->data_type() == DataType::S32)
+    if(data_type == DataType::S32)
     {
         num_elems_processed_per_iteration = 16;
     }
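
The step size follows from packing 128-bit NEON registers; a sketch of the rule
applied above (the S32 override exists so each iteration fills an int32x4x4_t
and emits one full 16-lane 8-bit store):

    // Sketch, not part of the patch:
    //   F32 : 16 / 4 = 4 elements (one float32x4_t)
    //   F16 : 16 / 2 = 8 elements (one float16x8_t)
    //   S32 : 16 elements         (int32x4x4_t, i.e. four registers)
    unsigned int elems_per_iteration(DataType dt)
    {
        return (dt == DataType::S32) ? 16 : 16 / element_size_from_data_type(dt);
    }
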
@@ -150,107 +158,44 @@
     return std::make_pair(err, win);
 }
 
-// Internal load
-inline float32x4_t internal_vld1q(const float *in)
+template <typename T, bool has_bias>
+typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
+output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+                  int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
 {
-    return vld1q_f32(in);
-}
+    /** NEON vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
 
-// Internal store
-inline void internal_vst1q(float *p, const float32x4_t &v)
-{
-    vst1q_f32(p, v);
-}
-
-// Internal vdup
-inline float32x4_t internal_vdupq_n(float v)
-{
-    return vdupq_n_f32(v);
-}
-
-// Internal vadd
-inline float32x4_t internal_vqaddq(const float32x4_t &x, const float32x4_t &y)
-{
-    return vaddq_f32(x, y);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-inline float16x8_t internal_vld1q(const float16_t *in)
-{
-    return vld1q_f16(in);
-}
-inline void internal_vst1q(float16_t *p, const float16x8_t &v)
-{
-    vst1q_f16(p, v);
-}
-inline float16x8_t internal_vdupq_n(float16_t v)
-{
-    return vdupq_n_f16(v);
-}
-inline float16x8_t internal_vqaddq(const float16x8_t &x, const float16x8_t &y)
-{
-    return vaddq_f16(x, y);
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-template <typename T1, typename T2, bool in_place, bool has_bias>
-void output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
-                       int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
-{
     ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
     ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
     ARM_COMPUTE_UNUSED(result_shift);
     ARM_COMPUTE_UNUSED(result_offset_after_shift);
 
     Iterator in(input, window);
-
-    if(in_place) // In place accumulate
+    Iterator out(output, window);
+    execute_window_loop(window, [&](const Coordinates & id)
     {
-        execute_window_loop(window, [&](const Coordinates & id)
-        {
-            // Get bias and pointer to input
-            const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
+        // Get bias and pointer to input
+        const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
+        auto       v_in   = wrapper::vloadq(in_ptr);
 
-            // Accumulate bias
-            if(has_bias)
-            {
-                const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
-                internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
-            }
-            else
-            {
-                internal_vst1q(in_ptr, internal_vld1q(in_ptr));
-            }
-        },
-        in);
-    }
-    else // Out of place accumulate
-    {
-        Iterator out(output, window);
-        execute_window_loop(window, [&](const Coordinates & id)
+        // Accumulate bias
+        if(has_bias)
         {
-            // Get bias and pointer to input
-            const auto in_ptr  = reinterpret_cast<const T1 *>(in.ptr());
-            const auto out_ptr = reinterpret_cast<T2 *>(out.ptr());
+            const auto vb = wrapper::vdup_n(*reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z()))), ExactTagType{});
+            v_in          = wrapper::vadd(v_in, vb);
+        }
 
-            // Accumulate bias
-            if(has_bias)
-            {
-                const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
-                internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
-            }
-            else
-            {
-                internal_vst1q(out_ptr, internal_vld1q(in_ptr));
-            }
-        },
-        in, out);
-    }
+        const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+        wrapper::vstore(out_ptr, v_in);
+    },
+    in, out);
 }
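
The refactor replaces the per-type internal_vld1q/internal_vst1q helpers with
the type-generic wrapper API, which resolves the NEON intrinsic from the element
type plus a bit-width tag. A minimal sketch of the pattern, assuming the usual
arm_compute wrapper headers:

    template <typename T>
    void add_bias_inplace(T *inout, T bias_value)
    {
        using Tag = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
        auto v = wrapper::vloadq(inout);                               // e.g. vld1q_f32 when T = float
        v      = wrapper::vadd(v, wrapper::vdup_n(bias_value, Tag{})); // vaddq_f32 / vaddq_f16 / ...
        wrapper::vstore(inout, v);                                     // vst1q_f32 / vst1q_f16 / ...
    }
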
 
-template <typename T1, typename T2, bool in_place, bool has_bias>
-void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
-                       int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+template <typename T, bool has_bias>
+typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
+output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+                  int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
 {
     ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
     ARM_COMPUTE_UNUSED(result_shift);
@@ -263,59 +208,39 @@
 
     Iterator in(input, window);
     Iterator bi(bias, window_bias);
-
-    if(in_place) // In place accumulate
+    Iterator out(output, window);
+    execute_window_loop(window, [&](const Coordinates &)
     {
-        execute_window_loop(window, [&](const Coordinates &)
-        {
-            // Get bias and pointer to input
-            const auto in_ptr   = reinterpret_cast<T1 *>(in.ptr());
-            const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+        // Get bias and pointer to input
+        const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
+        auto       v_in   = wrapper::vloadq(in_ptr);
 
-            // Accumulate bias
-            if(has_bias)
-            {
-                internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
-            }
-            else
-            {
-                internal_vst1q(in_ptr, internal_vld1q(in_ptr));
-            }
-        },
-        in, bi);
-    }
-    else // Out of place accumulate
-    {
-        Iterator out(output, window);
-        execute_window_loop(window, [&](const Coordinates &)
+        // Accumulate bias
+        if(has_bias)
         {
-            // Get bias and pointer to input
-            const auto in_ptr   = reinterpret_cast<T1 *>(in.ptr());
-            const auto out_ptr  = reinterpret_cast<T2 *>(out.ptr());
-            const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+            const auto bias_ptr = reinterpret_cast<T *>(bi.ptr());
+            v_in                = wrapper::vadd(v_in, wrapper::vloadq(bias_ptr));
+        }
 
-            // Accumulate bias
-            if(has_bias)
-            {
-                internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
-            }
-            else
-            {
-                internal_vst1q(out_ptr, internal_vld1q(in_ptr));
-            }
-        },
-        in, bi, out);
-    }
+        const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+        wrapper::vstore(out_ptr, v_in);
+    },
+    in, bi, out);
 }
 
-// QASYMM8 specializations
-template <>
-void output_stage_nchw<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
-                                                      int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+// Quantized case
+template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
+void output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+                       int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
 {
+    using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
+    using TagType    = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
+
     const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
-    uint8x16_t      min                           = vdupq_n_u8(0);
-    uint8x16_t      max                           = vdupq_n_u8(255);
+
+    const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
+    const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
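
The clamping bounds are now derived from the output element type instead of
being hardcoded for uint8_t; concretely (a sketch, values by definition of the
types, assuming <limits>):

    // TOut = uint8_t : [0, 255]    (QASYMM8)
    // TOut = int8_t  : [-128, 127] (QASYMM8_SIGNED)
    static_assert(std::numeric_limits<int8_t>::lowest() == -128 && std::numeric_limits<int8_t>::max() == 127, "");
    static_assert(std::numeric_limits<uint8_t>::lowest() == 0 && std::numeric_limits<uint8_t>::max() == 255, "");
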
 
     Iterator in(input, window);
     Iterator out(output, window);
@@ -327,68 +252,44 @@
         int32x4x4_t v_in =
         {
             {
-                vld1q_s32(in_ptr),
-                vld1q_s32(in_ptr + 4),
-                vld1q_s32(in_ptr + 8),
-                vld1q_s32(in_ptr + 12)
+                wrapper::vloadq(in_ptr),
+                wrapper::vloadq(in_ptr + 4),
+                wrapper::vloadq(in_ptr + 8),
+                wrapper::vloadq(in_ptr + 12)
             }
         };
 
         // Accumulate bias
-        const auto vb = vdupq_n_s32(*reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))));
-        v_in =
+        if(has_bias)
         {
+            const auto vb = wrapper::vdup_n(*reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))), TagType{});
+            v_in =
             {
-                vaddq_s32(v_in.val[0], vb),
-                vaddq_s32(v_in.val[1], vb),
-                vaddq_s32(v_in.val[2], vb),
-                vaddq_s32(v_in.val[3], vb)
-            }
-        };
+                {
+                    wrapper::vadd(v_in.val[0], vb),
+                    wrapper::vadd(v_in.val[1], vb),
+                    wrapper::vadd(v_in.val[2], vb),
+                    wrapper::vadd(v_in.val[3], vb)
+                }
+            };
+        }
 
-        const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
-        vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
+        const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
+        wrapper::vstore(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
     },
     in, out);
 }
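
For clarity, a scalar sketch of what the finalize_quantization call computes per
lane (rounding simplified; the kernel itself uses the saturating NEON
fixed-point ops, and <limits>/<algorithm> are assumed):

    template <typename TOut>
    TOut requantize_lane(int32_t x, int32_t multiplier, int32_t shift, int32_t offset)
    {
        // Rounded Q0.31 fixed-point multiply, approximating vqrdmulhq_s32.
        int64_t y = (static_cast<int64_t>(x) * multiplier + (int64_t(1) << 30)) >> 31;
        // Rounded right shift by the power-of-two divisor.
        if(shift > 0)
        {
            y = (y + (int64_t(1) << (shift - 1))) >> shift;
        }
        y += offset; // add the output zero-point
        // Clamp to the representable range of the output type (min/max above).
        y = std::max<int64_t>(std::numeric_limits<TOut>::lowest(),
                              std::min<int64_t>(std::numeric_limits<TOut>::max(), y));
        return static_cast<TOut>(y);
    }
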
-template <>
-void output_stage_nchw<int32_t, uint8_t, false, false>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
-                                                       int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
+void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+                       int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
 {
-    ARM_COMPUTE_UNUSED(bias);
+    using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
+    using TagType    = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
 
     const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
-    uint8x16_t      min                           = vdupq_n_u8(0);
-    uint8x16_t      max                           = vdupq_n_u8(255);
 
-    Iterator in(input, window);
-    Iterator out(output, window);
-    execute_window_loop(window, [&](const Coordinates &)
-    {
-        // Get bias and pointer to input
-        const auto  in_ptr = reinterpret_cast<int32_t *>(in.ptr());
-        int32x4x4_t v_in =
-        {
-            {
-                vld1q_s32(in_ptr),
-                vld1q_s32(in_ptr + 4),
-                vld1q_s32(in_ptr + 8),
-                vld1q_s32(in_ptr + 12)
-            }
-        };
-
-        const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
-        vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
-    },
-    in, out);
-}
-template <>
-void output_stage_nhwc<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
-                                                      int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
-{
-    const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
-    uint8x16_t      min                           = vdupq_n_u8(0);
-    uint8x16_t      max                           = vdupq_n_u8(255);
+    const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
+    const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
 
     Window window_bias = window;
     window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
@@ -402,57 +303,33 @@
     execute_window_loop(window, [&](const Coordinates &)
     {
         // Get bias and pointer to input
-        const auto in_ptr   = reinterpret_cast<int32_t *>(in.ptr());
-        const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr());
+        const auto  in_ptr = reinterpret_cast<int32_t *>(in.ptr());
+        int32x4x4_t v_in =
+        {
+            {
+                wrapper::vloadq(in_ptr),
+                wrapper::vloadq(in_ptr + 4),
+                wrapper::vloadq(in_ptr + 8),
+                wrapper::vloadq(in_ptr + 12),
+            }
+        };
 
         // Accumulate bias
-        int32x4x4_t v_in =
+        if(has_bias)
         {
-            {
-                vaddq_s32(vld1q_s32(in_ptr), vld1q_s32(bias_ptr)),
-                vaddq_s32(vld1q_s32(in_ptr + 4), vld1q_s32(bias_ptr + 4)),
-                vaddq_s32(vld1q_s32(in_ptr + 8), vld1q_s32(bias_ptr + 8)),
-                vaddq_s32(vld1q_s32(in_ptr + 12), vld1q_s32(bias_ptr + 12))
-            }
-        };
+            const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr());
 
-        const auto out_ptr = out.ptr();
-        vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
+            v_in.val[0] = wrapper::vadd(v_in.val[0], wrapper::vloadq(bias_ptr));
+            v_in.val[1] = wrapper::vadd(v_in.val[1], wrapper::vloadq(bias_ptr + 4));
+            v_in.val[2] = wrapper::vadd(v_in.val[2], wrapper::vloadq(bias_ptr + 8));
+            v_in.val[3] = wrapper::vadd(v_in.val[3], wrapper::vloadq(bias_ptr + 12));
+        }
+
+        const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
+        wrapper::vstore(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
     },
     in, bi, out);
 }
-template <>
-void output_stage_nhwc<int32_t, uint8_t, false, false>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
-                                                       int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
-{
-    ARM_COMPUTE_UNUSED(bias);
-
-    const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
-    uint8x16_t      min                           = vdupq_n_u8(0);
-    uint8x16_t      max                           = vdupq_n_u8(255);
-
-    Iterator in(input, window);
-    Iterator out(output, window);
-    execute_window_loop(window, [&](const Coordinates &)
-    {
-        // Get pointer to input
-        const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
-
-        int32x4x4_t v_in =
-        {
-            {
-                vld1q_s32(in_ptr),
-                vld1q_s32(in_ptr + 4),
-                vld1q_s32(in_ptr + 8),
-                vld1q_s32(in_ptr + 12)
-            }
-        };
-
-        const auto out_ptr = out.ptr();
-        vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
-    },
-    in, out);
-}
 } // namespace
 
 NEDirectConvolutionLayerOutputStageKernel::NEDirectConvolutionLayerOutputStageKernel()
@@ -461,37 +338,27 @@
 }
 
 void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const ITensor *bias, ITensor *output,
-                                                          int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+                                                          const DirectConvolutionLayerOutputStageKernelInfo &info)
 {
-    ARM_COMPUTE_ERROR_ON_NULLPTR(input);
-
-    // Auto-initialize output output if required
-    if(output != nullptr)
-    {
-        // Work out expected output data type
-        const DataType output_dt = (input->info()->data_type() == DataType::S32) ? DataType::QASYMM8 : input->info()->data_type();
-        // Output tensor auto initialization if not yet initialized
-        auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(output_dt));
-    }
-
     // Perform validation step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info(),
-                                                  result_fixedpoint_multiplier, result_shift, result_offset_after_shift));
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input);
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info(), info));
 
     _func                         = nullptr;
     _bias                         = bias;
     _input                        = input;
-    _output                       = output;
-    _result_fixedpoint_multiplier = result_fixedpoint_multiplier;
-    _result_shift                 = result_shift;
-    _result_offset_after_shift    = result_offset_after_shift;
+    _output                       = (output != nullptr) ? output : input;
+    _result_fixedpoint_multiplier = info.result_fixedpoint_multiplier;
+    _result_shift                 = info.result_shift;
+    _result_offset_after_shift    = info.result_offset_after_shift;
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info());
+    auto win_config = validate_and_configure_window(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info(), info);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     INEKernel::configure(win_config.second);
 
-    const bool has_bias = bias != nullptr;
+    const bool has_bias          = bias != nullptr;
+    const bool is_qasymm8_signed = (output != nullptr) ? is_data_type_quantized_asymmetric_signed(output->info()->data_type()) : false;
 
     // Set appropriate function
     if(input->info()->data_layout() == DataLayout::NCHW)
@@ -500,33 +367,26 @@
         {
             case DataType::S32:
             {
-                _func = (bias == nullptr) ? &output_stage_nchw<int32_t, uint8_t, false, false> : &output_stage_nchw<int32_t, uint8_t, false, true>;
+                if(is_qasymm8_signed)
+                {
+                    _func = (has_bias) ? &output_stage_nchw<int8_t, true> : &output_stage_nchw<int8_t, false>;
+                }
+                else
+                {
+                    _func = (has_bias) ? &output_stage_nchw<uint8_t, true> : &output_stage_nchw<uint8_t, false>;
+                }
                 break;
             }
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
             case DataType::F16:
             {
-                if(has_bias)
-                {
-                    _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>;
-                }
-                else
-                {
-                    _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, false> : &output_stage_nchw<float16_t, float16_t, false, false>;
-                }
+                _func = (has_bias) ? &output_stage_nchw<float16_t, true> : &output_stage_nchw<float16_t, false>;
                 break;
             }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
             case DataType::F32:
             {
-                if(has_bias)
-                {
-                    _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>;
-                }
-                else
-                {
-                    _func = (output == nullptr) ? &output_stage_nchw<float, float, true, false> : &output_stage_nchw<float, float, false, false>;
-                }
+                _func = (has_bias) ? &output_stage_nchw<float, true> : &output_stage_nchw<float, false>;
                 break;
             }
             default:
@@ -541,33 +401,26 @@
         {
             case DataType::S32:
             {
-                _func = (bias == nullptr) ? &output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>;
+                if(is_qasymm8_signed)
+                {
+                    _func = (has_bias) ? &output_stage_nhwc<int8_t, true> : &output_stage_nhwc<int8_t, false>;
+                }
+                else
+                {
+                    _func = (has_bias) ? &output_stage_nhwc<uint8_t, true> : &output_stage_nhwc<uint8_t, false>;
+                }
                 break;
             }
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
             case DataType::F16:
             {
-                if(has_bias)
-                {
-                    _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>;
-                }
-                else
-                {
-                    _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, false> : &output_stage_nhwc<float16_t, float16_t, false, false>;
-                }
+                _func = (has_bias) ? &output_stage_nhwc<float16_t, true> : &output_stage_nhwc<float16_t, false>;
                 break;
             }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
             case DataType::F32:
             {
-                if(has_bias)
-                {
-                    _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
-                }
-                else
-                {
-                    _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, false> : &output_stage_nhwc<float, float, false, false>;
-                }
+                _func = (has_bias) ? &output_stage_nhwc<float, true> : &output_stage_nhwc<float, false>;
                 break;
             }
             default:
@@ -579,10 +432,14 @@
 }
 
 Status NEDirectConvolutionLayerOutputStageKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
-                                                           int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+                                                           const DirectConvolutionLayerOutputStageKernelInfo &info)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, result_fixedpoint_multiplier, result_shift, result_offset_after_shift));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), bias == nullptr ? nullptr : bias->clone().get(), output == nullptr ? nullptr : output->clone().get()).first);
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, info));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
+                                                              bias == nullptr ? nullptr : bias->clone().get(),
+                                                              output == nullptr ? nullptr : output->clone().get(),
+                                                              info)
+                                .first);
 
     return Status{};
 }
diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
index ddcc71f..0320002 100644
--- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -67,7 +67,9 @@
 
         if(is_quantized)
         {
-            ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayerOutputStageKernel::validate(&accumulator, biases, output));
+            DirectConvolutionLayerOutputStageKernelInfo direct_conv_info;
+            direct_conv_info.output_data_type = input->data_type();
+            ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayerOutputStageKernel::validate(&accumulator, biases, output, direct_conv_info));
         }
     }
     else
@@ -196,7 +198,13 @@
         int32_t output_multiplier;
         int32_t output_shift;
         quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
-        _output_stage_kernel.configure(&_accumulator, biases, _is_nchw ? output : &_permuted_output, output_multiplier, output_shift, oq_info.offset);
+
+        DirectConvolutionLayerOutputStageKernelInfo direct_conv_info;
+        direct_conv_info.result_fixedpoint_multiplier = output_multiplier;
+        direct_conv_info.result_shift                 = output_shift;
+        direct_conv_info.result_offset_after_shift    = oq_info.offset;
+        direct_conv_info.output_data_type             = input->info()->data_type();
+        _output_stage_kernel.configure(&_accumulator, biases, _is_nchw ? output : &_permuted_output, direct_conv_info);
         _accumulator.allocator()->allocate();
     }
     else if(_has_bias)
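
For context, a sketch of how the fixed-point parameters above are typically
derived from the tensor quantization scales (variable names illustrative;
calculate_quantized_multiplier is the helper already called in this hunk):

    const UniformQuantizationInfo iq_info = input->info()->quantization_info().uniform();
    const UniformQuantizationInfo wq_info = weights->info()->quantization_info().uniform();
    const UniformQuantizationInfo oq_info = output->info()->quantization_info().uniform();

    // Effective real-valued rescale factor of the convolution accumulator.
    const float multiplier = iq_info.scale * wq_info.scale / oq_info.scale;

    int32_t output_multiplier = 0;
    int32_t output_shift      = 0;
    quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
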