COMPMID-812 Add NHWC data format support for NEON depthwise convolution (optimized case).

Change-Id: Icdfd6c02ed526daf4f59a4b76c7bbc1bc48fde74
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125938
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index 08d8f8c..edda2cd 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -44,6 +44,7 @@
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8,
                                                          DataType::QS16, DataType::F16,
                                                          DataType::QS32, DataType::S32, DataType::F32);
@@ -68,6 +69,7 @@
             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
         }
 
+        ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)));
         ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
     }
     else
@@ -79,6 +81,8 @@
     if((output != nullptr) && (output->total_size() != 0))
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+
         if(is_data_type_fixed_point(input->data_type()))
         {
             ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS8 && output->data_type() != DataType::QS8, "Wrong data type for output");
@@ -101,6 +105,8 @@
 
 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output)
 {
+    ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
+
     bool         window_changed                    = false;
     unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type());
 
@@ -138,8 +144,16 @@
         }
         else
         {
-            AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
-            window_changed = update_window_and_padding(win, input_access, bias_access);
+            if(input->data_layout() == DataLayout::NCHW)
+            {
+                AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
+                window_changed = update_window_and_padding(win, input_access, bias_access);
+            }
+            else
+            {
+                AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration);
+                window_changed = update_window_and_padding(win, input_access, bias_access);
+            }
         }
 
         input_access.set_valid_region(win, ValidRegion(Coordinates(), input->tensor_shape()));
@@ -253,6 +267,7 @@
 void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
                   int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
 {
+    ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
     ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
     ARM_COMPUTE_UNUSED(result_shift);
     ARM_COMPUTE_UNUSED(result_offset_after_shift);
@@ -303,6 +318,66 @@
     }
 }
 
+template <typename T1, typename T2, bool in_place, bool has_bias>
+void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+                       int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+{
+    ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
+    ARM_COMPUTE_UNUSED(result_shift);
+    ARM_COMPUTE_UNUSED(result_offset_after_shift);
+    // Bias is a 1D per-channel vector: it walks X only, so Y, Z and dimension 3 are pinned to 0.
+    Window window_bias = window;
+    window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
+    window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
+    window_bias.set(3, Window::Dimension(0, 0, 0));
+
+    Iterator in(input, window);
+    Iterator bi(bias, window_bias);
+
+    if(in_place) // In place accumulate
+    {
+        execute_window_loop(window, [&](const Coordinates &)
+        {
+            // Get bias and pointer to input
+            const auto in_ptr   = reinterpret_cast<T1 *>(in.ptr());
+            const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+
+            // Accumulate bias
+            if(has_bias)
+            {
+                internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
+            }
+            else
+            {
+                internal_vst1q(in_ptr, internal_vld1q(in_ptr));
+            }
+        },
+        in, bi);
+    }
+    else // Out of place accumulate
+    {
+        Iterator out(output, window); // must be passed to execute_window_loop so it advances
+        execute_window_loop(window, [&](const Coordinates &)
+        {
+            // Get pointers to input, output and bias
+            const auto in_ptr   = reinterpret_cast<T1 *>(in.ptr());
+            const auto out_ptr  = reinterpret_cast<T2 *>(out.ptr());
+            const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+
+            // Accumulate bias
+            if(has_bias)
+            {
+                internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
+            }
+            else
+            {
+                internal_vst1q(out_ptr, internal_vld1q(in_ptr));
+            }
+        },
+        in, bi, out);
+    }
+}
+
 // QASYMM8 specializations
 template <>
 void output_stage<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
@@ -415,61 +490,79 @@
     INEKernel::configure(win_config.second);
 
     // Set appropriate function
-    switch(input->info()->data_type())
+    if(input->info()->data_layout() == DataLayout::NCHW)
     {
-        case DataType::QS8:
+        switch(input->info()->data_type())
         {
-            if(bias == nullptr)
+            case DataType::QS8:
             {
-                _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, false> : &output_stage<qint8_t, qint8_t, false, false>;
+                if(bias == nullptr)
+                {
+                    _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, false> : &output_stage<qint8_t, qint8_t, false, false>;
+                }
+                else
+                {
+                    _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, true> : &output_stage<qint8_t, qint8_t, false, true>;
+                }
+                break;
             }
-            else
+            case DataType::QS16:
             {
-                _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, true> : &output_stage<qint8_t, qint8_t, false, true>;
+                if(bias != nullptr && bias->info()->data_type() == DataType::QS8)
+                {
+                    _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, true> : &output_stage<qint16_t, qint8_t, false, true>;
+                }
+                else if(bias == nullptr)
+                {
+                    _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, false> : &output_stage<qint16_t, qint8_t, false, false>;
+                }
+                else
+                {
+                    ARM_COMPUTE_ERROR("Not implemented");
+                }
+                break;
             }
-            break;
-        }
-        case DataType::QS16:
-        {
-            if(bias != nullptr && bias->info()->data_type() == DataType::QS8)
+            case DataType::QS32:
             {
-                _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, true> : &output_stage<qint16_t, qint8_t, false, true>;
+                _func = (output == nullptr) ? &output_stage<qint32_t, qint16_t, true, true> : &output_stage<qint32_t, qint16_t, false, true>;
+                break;
             }
-            else if(bias == nullptr)
+            case DataType::S32:
             {
-                _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, false> : &output_stage<qint16_t, qint8_t, false, false>;
+                _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
+                break;
             }
-            else
-            {
-                ARM_COMPUTE_ERROR("Not implemented");
-            }
-            break;
-        }
-        case DataType::QS32:
-        {
-            _func = (output == nullptr) ? &output_stage<qint32_t, qint16_t, true, true> : &output_stage<qint32_t, qint16_t, false, true>;
-            break;
-        }
-        case DataType::S32:
-        {
-            _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
-            break;
-        }
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-        case DataType::F16:
-        {
-            _func = (output == nullptr) ? &output_stage<float16_t, float16_t, true, true> : &output_stage<float16_t, float16_t, false, true>;
-            break;
-        }
+            case DataType::F16:
+            {
+                _func = (output == nullptr) ? &output_stage<float16_t, float16_t, true, true> : &output_stage<float16_t, float16_t, false, true>;
+                break;
+            }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-        case DataType::F32:
-        {
-            _func = (output == nullptr) ? &output_stage<float, float, true, true> : &output_stage<float, float, false, true>;
-            break;
+            case DataType::F32:
+            {
+                _func = (output == nullptr) ? &output_stage<float, float, true, true> : &output_stage<float, float, false, true>;
+                break;
+            }
+            default:
+            {
+                ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            }
         }
-        default:
+    }
+    else
+    {
+        switch(input->info()->data_type())
         {
-            ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            case DataType::F32:
+            {
+                _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
+                break;
+            }
+            default:
+            {
+                ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            }
         }
     }
 }