COMPMID-812 Add NHWC data format support for NEON depthwise convolution (optimized case).

Change-Id: Icdfd6c02ed526daf4f59a4b76c7bbc1bc48fde74
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125938
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
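
In NHWC the channel dimension is the innermost (x) axis, so the new output_stage_nhwc path can read the bias through a plain AccessWindowHorizontal and accumulate it with the same internal_vld1q/internal_vqaddq/internal_vst1q helpers used by the NCHW path, without the static bias access window. As a rough illustration of the access pattern (not part of the patch; plain arm_neon.h intrinsics, hypothetical names, F32 only, channel count assumed to be a multiple of 4):

    // Illustrative sketch only: bias accumulation over an NHWC F32 buffer,
    // where channels are contiguous in memory. Hypothetical helper, not ACL code.
    #include <arm_neon.h>
    #include <cstddef>

    void accumulate_bias_nhwc(float *data, const float *bias,
                              size_t num_pixels, size_t num_channels)
    {
        for(size_t p = 0; p < num_pixels; ++p)
        {
            float *row = data + p * num_channels;
            for(size_t c = 0; c < num_channels; c += 4)
            {
                // Channels are innermost, so both the data and the bias are
                // loaded as contiguous 4-wide vectors.
                const float32x4_t v = vld1q_f32(row + c);
                const float32x4_t b = vld1q_f32(bias + c);
                vst1q_f32(row + c, vaddq_f32(v, b));
            }
        }
    }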
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index 08d8f8c..edda2cd 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -44,6 +44,7 @@
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8,
                                                          DataType::QS16, DataType::F16,
                                                          DataType::QS32, DataType::S32, DataType::F32);
@@ -68,6 +69,7 @@
             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
         }
 
+        ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)));
         ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
     }
     else
@@ -79,6 +81,8 @@
     if((output != nullptr) && (output->total_size() != 0))
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+
         if(is_data_type_fixed_point(input->data_type()))
         {
             ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS8 && output->data_type() != DataType::QS8, "Wrong data type for output");
@@ -101,6 +105,8 @@
 
 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output)
 {
+    ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
+
     bool         window_changed                    = false;
     unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type());
 
@@ -138,8 +144,16 @@
         }
         else
         {
-            AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
-            window_changed = update_window_and_padding(win, input_access, bias_access);
+            if(input->data_layout() == DataLayout::NCHW)
+            {
+                AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
+                window_changed = update_window_and_padding(win, input_access, bias_access);
+            }
+            else
+            {
+                AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration);
+                window_changed = update_window_and_padding(win, input_access, bias_access);
+            }
         }
 
         input_access.set_valid_region(win, ValidRegion(Coordinates(), input->tensor_shape()));
@@ -253,6 +267,7 @@
 void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
                   int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
 {
+    ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
     ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
     ARM_COMPUTE_UNUSED(result_shift);
     ARM_COMPUTE_UNUSED(result_offset_after_shift);
@@ -303,6 +318,66 @@
     }
 }
 
+template <typename T1, typename T2, bool in_place, bool has_bias>
+void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+                       int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+{
+    ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
+    ARM_COMPUTE_UNUSED(result_shift);
+    ARM_COMPUTE_UNUSED(result_offset_after_shift);
+
+    Window window_bias = window;
+    window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
+    window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
+    window_bias.set(3, Window::Dimension(0, 0, 0));
+
+    Iterator in(input, window);
+    Iterator bi(bias, window_bias);
+
+    if(in_place) // In place accumulate
+    {
+        execute_window_loop(window, [&](const Coordinates & id)
+        {
+            // Get bias and pointer to input
+            const auto in_ptr   = reinterpret_cast<T1 *>(in.ptr());
+            const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+
+            // Accumulate bias
+            if(has_bias)
+            {
+                internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
+            }
+            else
+            {
+                internal_vst1q(in_ptr, internal_vld1q(in_ptr));
+            }
+        },
+        in, bi);
+    }
+    else // Out of place accumulate
+    {
+        Iterator out(output, window);
+        execute_window_loop(window, [&](const Coordinates & id)
+        {
+            // Get pointers to input, output and bias
+            const auto in_ptr   = reinterpret_cast<T1 *>(in.ptr());
+            const auto out_ptr  = reinterpret_cast<T2 *>(out.ptr());
+            const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+
+            // Accumulate bias
+            if(has_bias)
+            {
+                internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
+            }
+            else
+            {
+                internal_vst1q(out_ptr, internal_vld1q(in_ptr));
+            }
+        },
+        in, bi);
+    }
+}
+
 // QASYMM8 specializations
 template <>
 void output_stage<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
@@ -415,61 +490,79 @@
     INEKernel::configure(win_config.second);
 
     // Set appropriate function
-    switch(input->info()->data_type())
+    if(input->info()->data_layout() == DataLayout::NCHW)
     {
-        case DataType::QS8:
+        switch(input->info()->data_type())
         {
-            if(bias == nullptr)
+            case DataType::QS8:
             {
-                _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, false> : &output_stage<qint8_t, qint8_t, false, false>;
+                if(bias == nullptr)
+                {
+                    _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, false> : &output_stage<qint8_t, qint8_t, false, false>;
+                }
+                else
+                {
+                    _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, true> : &output_stage<qint8_t, qint8_t, false, true>;
+                }
+                break;
             }
-            else
+            case DataType::QS16:
             {
-                _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, true> : &output_stage<qint8_t, qint8_t, false, true>;
+                if(bias != nullptr && bias->info()->data_type() == DataType::QS8)
+                {
+                    _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, true> : &output_stage<qint16_t, qint8_t, false, true>;
+                }
+                else if(bias == nullptr)
+                {
+                    _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, false> : &output_stage<qint16_t, qint8_t, false, false>;
+                }
+                else
+                {
+                    ARM_COMPUTE_ERROR("Not implemented");
+                }
+                break;
             }
-            break;
-        }
-        case DataType::QS16:
-        {
-            if(bias != nullptr && bias->info()->data_type() == DataType::QS8)
+            case DataType::QS32:
             {
-                _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, true> : &output_stage<qint16_t, qint8_t, false, true>;
+                _func = (output == nullptr) ? &output_stage<qint32_t, qint16_t, true, true> : &output_stage<qint32_t, qint16_t, false, true>;
+                break;
             }
-            else if(bias == nullptr)
+            case DataType::S32:
             {
-                _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, false> : &output_stage<qint16_t, qint8_t, false, false>;
+                _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
+                break;
             }
-            else
-            {
-                ARM_COMPUTE_ERROR("Not implemented");
-            }
-            break;
-        }
-        case DataType::QS32:
-        {
-            _func = (output == nullptr) ? &output_stage<qint32_t, qint16_t, true, true> : &output_stage<qint32_t, qint16_t, false, true>;
-            break;
-        }
-        case DataType::S32:
-        {
-            _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
-            break;
-        }
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-        case DataType::F16:
-        {
-            _func = (output == nullptr) ? &output_stage<float16_t, float16_t, true, true> : &output_stage<float16_t, float16_t, false, true>;
-            break;
-        }
+            case DataType::F16:
+            {
+                _func = (output == nullptr) ? &output_stage<float16_t, float16_t, true, true> : &output_stage<float16_t, float16_t, false, true>;
+                break;
+            }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-        case DataType::F32:
-        {
-            _func = (output == nullptr) ? &output_stage<float, float, true, true> : &output_stage<float, float, false, true>;
-            break;
+            case DataType::F32:
+            {
+                _func = (output == nullptr) ? &output_stage<float, float, true, true> : &output_stage<float, float, false, true>;
+                break;
+            }
+            default:
+            {
+                ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            }
         }
-        default:
+    }
+    else
+    {
+        switch(input->info()->data_type())
         {
-            ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            case DataType::F32:
+            {
+                _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
+                break;
+            }
+            default:
+            {
+                ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            }
         }
     }
 }
diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
index f28ed71..8691fb9 100644
--- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
@@ -37,7 +37,7 @@
 
 NEDepthwiseConvolutionLayer3x3::NEDepthwiseConvolutionLayer3x3()
     : _dwc_kernel(), _output_stage_kernel(), _border_handler(), _permute_input(), _permute_weights(), _permute_output(), _accumulator(), _input_nhwc(), _weights_hwio(), _output_nhwc(), _has_bias(false),
-      _is_quantized(false), _is_optimized(false), _are_weights_reshaped(false)
+      _is_quantized(false), _is_optimized(false), _are_weights_reshaped(false), _is_nchw(true), _is_first_run(true)
 {
 }
 
@@ -52,30 +52,38 @@
     _has_bias     = biases != nullptr;
     _is_optimized = NEDepthwiseConvolutionLayer3x3Kernel::is_optimized_execution_possible(input->info()->tensor_shape(),
                                                                                           conv_info,
-                                                                                          input->info()->data_type());
+                                                                                          input->info()->data_type(),
+                                                                                          input->info()->data_layout());
     _are_weights_reshaped = false;
+    _is_nchw              = input->info()->data_layout() == DataLayout::NCHW;
+
+    ARM_COMPUTE_ERROR_ON(!_is_optimized && !_is_nchw);
 
     if(_is_optimized)
     {
-        // Configure the function to transform the input tensor from NCHW -> NHWC
-        _permute_input.configure(input, &_input_nhwc, PermutationVector(2U, 0U, 1U));
+        if(_is_nchw)
+        {
+            // Configure the function to transform the input tensor from NCHW -> NHWC
+            _permute_input.configure(input, &_input_nhwc, PermutationVector(2U, 0U, 1U));
 
-        // Configure the function to transform the weights tensor from IHW -> HWI
-        _permute_weights.configure(weights, &_weights_hwio, PermutationVector(2U, 0U, 1U));
+            // Configure the function to transform the weights tensor from IHW -> HWI
+            _permute_weights.configure(weights, &_weights_hwio, PermutationVector(2U, 0U, 1U));
 
-        // Configure optimized depthwise
-        _dwc_kernel.configure(&_input_nhwc, &_weights_hwio, &_output_nhwc, conv_info, DataLayout::NHWC);
+            // Configure optimized depthwise
+            _dwc_kernel.configure(&_input_nhwc, &_weights_hwio, &_output_nhwc, conv_info, DataLayout::NHWC);
 
-        // Configure the function to transform the convoluted output to ACL's native ordering format NCHW
-        _permute_output.configure(&_output_nhwc, output, PermutationVector(1U, 2U, 0U));
+            // Configure the function to transform the convoluted output to ACL's native ordering format NCHW
+            _permute_output.configure(&_output_nhwc, output, PermutationVector(1U, 2U, 0U));
 
-        // Allocate tensors
-        _input_nhwc.allocator()->allocate();
-        _weights_hwio.allocator()->allocate();
-        _output_nhwc.allocator()->allocate();
-
-        // Create convolver (deferred)
-        _dwc_kernel.generate_convolver();
+            // Allocate tensors
+            _input_nhwc.allocator()->allocate();
+            _weights_hwio.allocator()->allocate();
+            _output_nhwc.allocator()->allocate();
+        }
+        else
+        {
+            _dwc_kernel.configure(input, weights, output, conv_info, DataLayout::NHWC);
+        }
     }
     else
     {
@@ -116,8 +124,15 @@
 
 void NEDepthwiseConvolutionLayer3x3::run()
 {
+    if(_is_first_run && _is_optimized)
+    {
+        _is_first_run = false;
+        // Create convolver (deferred)
+        _dwc_kernel.generate_convolver();
+    }
+
     // Permute weights in HWIO format if the optimized kernel will be executed
-    if(!_are_weights_reshaped && _is_optimized)
+    if(!_are_weights_reshaped && _is_optimized && _is_nchw)
     {
         _are_weights_reshaped = true;
         _permute_weights.run();
@@ -126,8 +141,11 @@
     // Handle input
     if(_is_optimized)
     {
-        // Permute input to NHWC format execution
-        _permute_input.run();
+        if(_is_nchw)
+        {
+            // Permute input to NHWC format execution
+            _permute_input.run();
+        }
     }
     else
     {
@@ -139,7 +157,7 @@
     NEScheduler::get().schedule(&_dwc_kernel, Window::DimX);
 
     // Permute output to ACL's native NCHW format in case of NHWC execution
-    if(_is_optimized)
+    if(_is_optimized && _is_nchw)
     {
         _permute_output.run();
     }