COMPMID-2378: Sanitize GEMM configuration for NEON

Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
index f4485bc..e1af2d4 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
@@ -61,12 +61,16 @@
                 switch ((y + 5) - ymax) {
                     case 4:
                         outptr1 = dummyres;
+                        // fall through
                     case 3:
                         outptr2 = dummyres;
+                        // fall through
                     case 2:
                         outptr3 = dummyres;
+                        // fall through
                     case 1:
                         outptr4 = dummyres;
+                        // fall through
                     case 0:
                         outptr5 = dummyres;
                         break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
index be23978..9fca4e3 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
@@ -63,16 +63,22 @@
                 switch ((y + 7) - ymax) {
                     case 6:
                         outptr1 = dummyres;
+                        // fall through
                     case 5:
                         outptr2 = dummyres;
+                        // fall through
                     case 4:
                         outptr3 = dummyres;
+                        // fall through
                     case 3:
                         outptr4 = dummyres;
+                        // fall through
                     case 2:
                         outptr5 = dummyres;
+                        // fall through
                     case 1:
                         outptr6 = dummyres;
+                        // fall through
                     case 0:
                         outptr7 = dummyres;
                         break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
index 9e5eb88..0e638ee 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
@@ -66,16 +66,22 @@
                 switch ((y + 7) - ymax) {
                     case 6:
                         outptr1 = dummyres;
+                        // fall through
                     case 5:
                         outptr2 = dummyres;
+                        // fall through
                     case 4:
                         outptr3 = dummyres;
+                        // fall through
                     case 3:
                         outptr4 = dummyres;
+                        // fall through
                     case 2:
                         outptr5 = dummyres;
+                        // fall through
                     case 1:
                         outptr6 = dummyres;
+                        // fall through
                     case 0:
                         outptr7 = dummyres;
                         break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
index 3ed43b1..60cc2f3 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
@@ -65,16 +65,22 @@
                 switch ((y + 7) - ymax) {
                     case 6:
                         outptr1 = dummyres;
+                        // fall through
                     case 5:
                         outptr2 = dummyres;
+                        // fall through
                     case 4:
                         outptr3 = dummyres;
+                        // fall through
                     case 3:
                         outptr4 = dummyres;
+                        // fall through
                     case 2:
                         outptr5 = dummyres;
+                        // fall through
                     case 1:
                         outptr6 = dummyres;
+                        // fall through
                     case 0:
                         outptr7 = dummyres;
                         break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
index 35d4cc5..0212dfd 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
@@ -63,16 +63,22 @@
                 switch ((y + 7) - ymax) {
                     case 6:
                         outptr1 = dummyres;
+                        // fall through
                     case 5:
                         outptr2 = dummyres;
+                        // fall through
                     case 4:
                         outptr3 = dummyres;
+                        // fall through
                     case 3:
                         outptr4 = dummyres;
+                        // fall through
                     case 2:
                         outptr5 = dummyres;
+                        // fall through
                     case 1:
                         outptr6 = dummyres;
+                        // fall through
                     case 0:
                         outptr7 = dummyres;
                         break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
index 20ad301..a460fdf 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
@@ -60,12 +60,16 @@
                     /* Everything falls through in here */
                     case 4:
                         inptr1 = zerobuff;
+                        // fall through
                     case 3:
                         inptr2 = zerobuff;
+                        // fall through
                     case 2:
                         inptr3 = zerobuff;
+                        // fall through
                     case 1:
                         inptr4 = zerobuff;
+                        // fall through
                     case 0:
                         inptr5 = zerobuff;
                         break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
index 2f513a6..6a15fc4 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
@@ -57,8 +57,10 @@
                     /* Everything falls through in here */
                     case 2:
                         inptr1 = zerobuff;
+                        // fall through
                     case 1:
                         inptr2 = zerobuff;
+                        // fall through
                     case 0:
                         inptr3 = zerobuff;
                         break;
@@ -93,8 +95,10 @@
                     /* Everything falls through in here */
                     case 2:
                         inptr1 = zerobuff;
+                        // fall through
                     case 1:
                         inptr2 = zerobuff;
+                        // fall through
                     case 0:
                         inptr3 = zerobuff;
                         break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
index 27136d1..0028ab0 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
@@ -64,16 +64,22 @@
                     /* Everything falls through in here */
                     case 6:
                         inptr1 = zerobuff;
+                        // fall through
                     case 5:
                         inptr2 = zerobuff;
+                        // fall through
                     case 4:
                         inptr3 = zerobuff;
+                        // fall through
                     case 3:
                         inptr4 = zerobuff;
+                        // fall through
                     case 2:
                         inptr5 = zerobuff;
+                        // fall through
                     case 1:
                         inptr6 = zerobuff;
+                        // fall through
                     case 0:
                         inptr7 = zerobuff;
                         break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
index 54822c8..758c084 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
@@ -64,16 +64,22 @@
                     /* Everything falls through in here */
                     case 6:
                         inptr1 = zerobuff;
+                        // fall through
                     case 5:
                         inptr2 = zerobuff;
+                        // fall through
                     case 4:
                         inptr3 = zerobuff;
+                        // fall through
                     case 3:
                         inptr4 = zerobuff;
+                        // fall through
                     case 2:
                         inptr5 = zerobuff;
+                        // fall through
                     case 1:
                         inptr6 = zerobuff;
+                        // fall through
                     case 0:
                         inptr7 = zerobuff;
                         break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
index 0606330..de8e95a 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
@@ -64,16 +64,22 @@
                     /* Everything falls through in here */
                     case 6:
                         inptr1 = zerobuff;
+                        // fall through
                     case 5:
                         inptr2 = zerobuff;
+                        // fall through
                     case 4:
                         inptr3 = zerobuff;
+                        // fall through
                     case 3:
                         inptr4 = zerobuff;
+                        // fall through
                     case 2:
                         inptr5 = zerobuff;
+                        // fall through
                     case 1:
                         inptr6 = zerobuff;
+                        // fall through
                     case 0:
                         inptr7 = zerobuff;
                         break;
diff --git a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
index 0fc3610..d00f204 100644
--- a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -33,11 +33,11 @@
 using namespace arm_compute;
 
 INEGEMMWrapperKernel::INEGEMMWrapperKernel()
-    : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _window3d(), _window_shape()
+    : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _gemm_info(), _window3d(), _window_shape()
 {
 }
 
-INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c)
+INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info)
 {
     Params p;
 
@@ -45,21 +45,30 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(b);
     ARM_COMPUTE_ERROR_ON_NULLPTR(c);
 
+    // Initialize params
     p.M       = c->info()->tensor_shape().y();
     p.N       = c->info()->tensor_shape().x();
     p.K       = a->info()->tensor_shape().x();
     p.multis  = b->info()->tensor_shape().z();
     p.batches = c->info()->tensor_shape().total_size_upper(2) / p.multis; //COMPMID-1423: Agree on and document the layout of gemm inputs/outputs
 
+    // Update M and batches in case of GEMM3D for output
+    if(gemm_info.depth_output_gemm3d() != 0)
+    {
+        p.M       = c->info()->tensor_shape().y() * c->info()->tensor_shape().z();
+        p.batches = c->info()->tensor_shape().total_size_upper(3) / p.multis;
+    }
+
     return p;
 }
 
-void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta)
+void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info)
 {
-    _params = extract_parameters(a, b, c);
-    _a      = a;
-    _b      = b;
-    _c      = c;
+    _gemm_info = gemm_info;
+    _params    = extract_parameters(a, b, c, gemm_info);
+    _a         = a;
+    _b         = b;
+    _c         = c;
 
     _window3d     = configure_internal(alpha, beta);
     _window_shape = _window3d.shape();
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
index 26d9e99..6e30148 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
@@ -76,32 +76,34 @@
      * @param[in] transformed_a Reshaped tensor A.
      * @param[in] block_walker  Window representing the layout of the matrix's blocks.
      * @param[in] params        M, N, K sizes.
+     * @param[in] gemm_info     GEMM meta-data
      *
      * @return A wrapped specialized transformA kernel
      */
     virtual std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor                      *a,
                                                                                        ITensor                            *transformed_a,
                                                                                        const Window                       &block_walker,
-                                                                                       const INEGEMMWrapperKernel::Params &params) = 0;
+                                                                                       const INEGEMMWrapperKernel::Params &params,
+                                                                                       const GEMMInfo                     &gemm_info) = 0;
     /** Instantiate and configure a prepareB Kernel
      *
-     * @param transformed_a  Already reshaped tensor A.
-     * @param transformed_b  Already reshaped tensor B.
-     * @param tmp_c          Temporary buffer to be used to store intermediate results.
-     * @param c              Result tensor C.
-     * @param block_walker   Window containing iteration information for the M and batch dimensions.
-     * @param block_sizes    Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
-     * @param params         M, N, K sizes.
-     * @param alpha          Alpha value
-     * @param beta           Beta value
-     * @param pretranspose_b Is B also pretransposed ?
-     * @param num_threads    Maximum number of threads that might be used for the calculations.
+     * @param[in] transformed_a Already reshaped tensor A.
+     * @param[in] transformed_b Already reshaped tensor B.
+     * @param[in] tmp_c         Temporary buffer to be used to store intermediate results.
+     * @param[in] c             Result tensor C.
+     * @param[in] block_walker  Window containing iteration information for the M and batch dimensions.
+     * @param[in] block_sizes   Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
+     * @param[in] params        M, N, K sizes.
+     * @param[in] alpha         Alpha value
+     * @param[in] beta          Beta value
+     * @param[in] gemm_info     GEMM meta-data
+     * @param[in] num_threads   Maximum number of threads that might be used for the calculations.
      *
      * @return A wrapped specialized MatrixMultiply kernel
      */
     virtual std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c,
                                                                                                 const Window &block_walker, const BlockSizes &block_sizes,
-                                                                                                const INEGEMMWrapperKernel::Params &params, float alpha, float beta, bool pretranspose_b,
+                                                                                                const INEGEMMWrapperKernel::Params &params, float alpha, float beta, const GEMMInfo &gemm_info,
                                                                                                 unsigned int num_threads) = 0;
     /** Calculates the block sizes of a given strategy
      *
@@ -138,19 +140,20 @@
     std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor                      *a,
                                                                                ITensor                            *transformed_a,
                                                                                const Window                       &block_walker,
-                                                                               const INEGEMMWrapperKernel::Params &params) override
+                                                                               const INEGEMMWrapperKernel::Params &params,
+                                                                               const GEMMInfo                     &gemm_info) override
     {
         auto transform_a = support::cpp14::make_unique<NEGEMMInterleavedTransformAWrapperTemplate<strategy>>();
-        transform_a->configure(a, transformed_a, false, block_walker, params);
+        transform_a->configure(a, transformed_a, false, gemm_info.reinterpret_input_as_3d(), block_walker, params);
         return std::move(transform_a);
     }
     std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c,
                                                                                         const Window &block_walker, const BlockSizes &block_sizes,
-                                                                                        const INEGEMMWrapperKernel::Params &params, float alpha, float beta, bool pretranspose_b,
+                                                                                        const INEGEMMWrapperKernel::Params &params, float alpha, float beta, const GEMMInfo &gemm_info,
                                                                                         unsigned int num_threads) override
     {
         auto matrix_multiply = support::cpp14::make_unique<NEGEMMInterleavedMatrixMultiplyWrapperTemplate<strategy>>();
-        matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, num_threads);
+        matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, gemm_info, alpha, beta, num_threads);
         return std::move(matrix_multiply);
     }
 
diff --git a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
index 97c20db..ecdb5a9 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
@@ -81,12 +81,20 @@
     TensorAccessor<To> b(*_b);
     TensorAccessor<Tr> c(*_c);
 
-    if(_a->info()->data_layout() == DataLayout::NHWC)
+    // Handle 3d input re-interpretation
+    if(_gemm_info.reinterpret_input_as_3d())
     {
-        // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
-        // the relevant multiple of the row stride.
-        const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _c->info()->dimension(1);
-        a.set_stride(2, nhwc_batch_stride);
+        Strides a_strides_as_3d = _a->info()->strides_in_bytes();
+        a_strides_as_3d.remove(Window::DimZ);
+        a.set_strides(a_strides_as_3d);
+    }
+
+    // Handle 3d output re-interpretation
+    if(_gemm_info.depth_output_gemm3d() != 0)
+    {
+        Strides c_strides_as_3d = _c->info()->strides_in_bytes();
+        c_strides_as_3d.remove(Window::DimZ);
+        c.set_strides(c_strides_as_3d);
     }
 
     unsigned int m_end = 0;