COMPMID-1519: Add support for 3D input/output in CLGEMMLowpOutputStage
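
The GEMMLowp output stage can now write its QASYMM8 result directly into
a 3D (or batched 4D) tensor: when output_3d_depth > 1, the flattened GEMM
rows are reinterpreted as (height, depth). This lets CLGEMMConvolutionLayer
skip col2im for NHWC, including the quantized path.

Minimal usage sketch (tensor names are illustrative only): quantize down
an S32 GEMM result of shape [W, H*D, B] into a QASYMM8 output of shape
[W, H, D, B]:

    CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint output_stage;
    output_stage.configure(&gemm_result_s32,  // S32 input of shape [W, H*D, B]
                           &biases,           // optional bias vector, may be nullptr
                           &output_qasymm8,   // QASYMM8 output of shape [W, H, D, B]
                           result_fixedpoint_multiplier, result_shift,
                           result_offset_after_shift,
                           0, 255,            // min/max: full uint8 range, no bounded ReLU
                           depth);            // output_3d_depth = D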

Change-Id: I637add70310d2da4d82b236a6352af9d33be17a1
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/149706
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index e52f1ea..e8124e7 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -2222,17 +2222,30 @@
 * @param[in]  dst_step_y                           dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                         Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                           dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  dst_stride_w                         Stride of the destination tensor in W dimension (in bytes)
+ * @param[in]  dst_step_w                           dst_stride_w * number of elements along W processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes    The offset of the first element in the destination tensor
  */
 __kernel void gemmlowp_output_stage_quantize_down_fixedpoint(TENSOR3D_DECLARATION(src),
 #if defined(ADD_BIAS)
                                                              VECTOR_DECLARATION(biases),
 #endif // defined(ADD_BIAS)
+#if defined(DST_HEIGHT)
+                                                             TENSOR4D_DECLARATION(dst))
+#else  // defined(DST_HEIGHT)
                                                              TENSOR3D_DECLARATION(dst))
+#endif // defined(DST_HEIGHT)
 {
     // Compute source and destination addresses
     Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+#if defined(DST_HEIGHT)
+    Tensor4D dst = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(dst, 1);
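+    // Split the flattened row index (global id 1) into (row % DST_HEIGHT, row / DST_HEIGHT) to address Y and Z of the 4D output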
+    dst.ptr += get_global_id(0) * dst_step_x + (get_global_id(1) % DST_HEIGHT) * dst_step_y + (get_global_id(1) / DST_HEIGHT) * dst_step_z + get_global_id(2) * dst_step_w;
+#else  // defined(DST_HEIGHT)
     Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
+#endif // defined(DST_HEIGHT)
+
 #if defined(ADD_BIAS)
     Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
 #endif // defined(ADD_BIAS)
diff --git a/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp b/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp
index 875e26d..d403d67 100644
--- a/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp
@@ -27,9 +27,12 @@
 #include "arm_compute/core/CL/ICLTensor.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+
 #include "support/ToolchainSupport.h"
 
 using namespace arm_compute;
@@ -38,7 +41,8 @@
 {
 namespace
 {
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
+                          int min, int max, unsigned int output_3d_depth)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::S32);
     ARM_COMPUTE_RETURN_ERROR_ON(max > 255);
@@ -54,8 +58,10 @@
 
     if(output->total_size() != 0)
     {
+        const TensorShape output_shape       = arm_compute::misc::shape_calculator::compute_output_stage_shape(*input, output_3d_depth, true);
+        const TensorInfo  tensor_info_output = output->clone()->set_tensor_shape(output_shape);
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
     }
 
     return Status{};
@@ -66,7 +72,7 @@
     constexpr unsigned int num_elems_processed_per_iteration = 16;
 
     // Configure kernel window
-    Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
+    Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
 
     AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
 
@@ -75,8 +81,9 @@
 
     if(output->total_size() != 0)
     {
+        Window                 win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
         AccessWindowHorizontal output_result_access(output, 0, num_elems_processed_per_iteration);
-        window_changed = window_changed || update_window_and_padding(win, output_result_access);
+        window_changed = window_changed || update_window_and_padding(win_out, output_result_access);
 
         output_result_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
     }
@@ -96,14 +103,15 @@
 } // namespace arm_compute
 
 CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel()
-    : _input(nullptr), _bias(nullptr), _output(nullptr)
+    : _input(nullptr), _bias(nullptr), _output(nullptr), _reinterpret_as_3d(false)
 {
 }
 
-Status CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max)
+Status CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
+                                                                           int min, int max, unsigned int output_3d_depth)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, min, max));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, min, max, output_3d_depth));
     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
                                                               (bias != nullptr) ? bias->clone().get() : nullptr,
                                                               output->clone().get())
@@ -112,24 +120,24 @@
     return Status{};
 }
 
-void CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output, int result_fixedpoint_multiplier, int result_shift,
-                                                                          int result_offset_after_shift, int min, int max)
+void CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output,
+                                                                          int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift,
+                                                                          int min, int max, unsigned int output_3d_depth)
 {
     // Perform validate step
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
 
     // Output auto initialization if not yet initialized
-    auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(DataType::QASYMM8));
+    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_output_stage_shape(*input->info(), output_3d_depth, true);
+    auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(DataType::QASYMM8).set_tensor_shape(output_shape));
 
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(),
-                                                  (bias != nullptr) ? bias->info() : nullptr,
-                                                  output->info(),
-                                                  min,
-                                                  max));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias != nullptr) ? bias->info() : nullptr, output->info(),
+                                                  min, max, output_3d_depth));
 
-    _input  = input;
-    _bias   = bias;
-    _output = output;
+    _input             = input;
+    _bias              = bias;
+    _output            = output;
+    _reinterpret_as_3d = output_3d_depth > 1;
 
     // Set the arguments to pass at compile time
     CLBuildOptions build_opts;
@@ -139,6 +147,7 @@
     build_opts.add_option_if((min != 0) && (min != max), "-DMIN_BOUND=" + support::cpp11::to_string(min));
     build_opts.add_option_if((max != 255) && (min != max), "-DMAX_BOUND=" + support::cpp11::to_string(max));
     build_opts.add_option_if(bias != nullptr, "-DADD_BIAS");
+    build_opts.add_option_if(_reinterpret_as_3d, "-DDST_HEIGHT=" + support::cpp11::to_string(input->info()->tensor_shape().y() / output_3d_depth));
 
     // Create kernel
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemmlowp_output_stage_quantize_down_fixedpoint", build_opts.options()));
@@ -154,9 +163,11 @@
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
 
+    // Create input window
     Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
     Window slice     = collapsed.first_slice_window_3D();
 
+    // Set up the bias slice
     unsigned int idx1 = num_arguments_per_3D_tensor();
     if(_bias != nullptr)
     {
@@ -166,12 +177,32 @@
         add_1D_tensor_argument(idx1, _bias, biases_slice);
     }
 
-    do
+    if(_reinterpret_as_3d)
     {
-        unsigned int idx = 0;
-        add_3D_tensor_argument(idx, _input, slice);
-        add_3D_tensor_argument(idx1, _output, slice);
-        enqueue(queue, *this, slice);
+        // Create output window
+        Window window_out;
+        window_out.use_tensor_dimensions(_output->info()->tensor_shape());
+        Window collapsed_out = window_out.collapse_if_possible(window_out, 3);
+        Window slice_out     = collapsed_out.first_slice_window_4D();
+
+        do
+        {
+            unsigned int idx = 0;
+            add_3D_tensor_argument(idx, _input, slice);
+            add_4D_tensor_argument(idx1, _output, slice_out);
+            enqueue(queue, *this, slice);
+        }
+        while(collapsed.slide_window_slice_3D(slice) && collapsed_out.slide_window_slice_4D(slice_out));
     }
-    while(collapsed.slide_window_slice_3D(slice));
+    else
+    {
+        do
+        {
+            unsigned int idx = 0;
+            add_3D_tensor_argument(idx, _input, slice);
+            add_3D_tensor_argument(idx1, _output, slice);
+            enqueue(queue, *this, slice);
+        }
+        while(collapsed.slide_window_slice_3D(slice));
+    }
 }
diff --git a/src/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp
index 0196bac..7cd50cc 100644
--- a/src/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp
@@ -28,6 +28,7 @@
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/NEON/NEAsymm.h"
+#include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
@@ -43,9 +44,8 @@
 namespace
 {
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
-                          int min, int max, unsigned int gemm_3d_depth)
+                          int min, int max, unsigned int output_3d_depth)
 {
-    ARM_COMPUTE_UNUSED(gemm_3d_depth);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::S32);
     ARM_COMPUTE_RETURN_ERROR_ON(max > 255);
     ARM_COMPUTE_RETURN_ERROR_ON(min < 0 || min > max);
@@ -60,21 +60,10 @@
 
     if(output->total_size() != 0)
     {
-        const TensorShape ref_shape    = output->tensor_shape();
-        const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_output_stage_shape(*input, gemm_3d_depth);
-        // Check in case of mismatching dimensions when permuting, usually in case of 1x1xC input shapes
-        if(output_shape.num_dimensions() != ref_shape.num_dimensions() && ref_shape.num_dimensions() < 4)
-        {
-            for(unsigned int i = output_shape.num_dimensions(); i < ref_shape.num_dimensions(); ++i)
-            {
-                ARM_COMPUTE_RETURN_ERROR_ON(ref_shape[i] != 1);
-            }
-        }
-        else
-        {
-            ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape() != output_shape);
-        }
+        const TensorShape output_shape       = arm_compute::misc::shape_calculator::compute_output_stage_shape(*input, output_3d_depth);
+        const TensorInfo  tensor_info_output = output->clone()->set_tensor_shape(output_shape);
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
     }
 
     return Status{};
@@ -160,7 +149,7 @@
     const int          window_step_x  = 16;
     const auto         window_start_x = static_cast<int>(window.x().start());
     const auto         window_end_x   = static_cast<int>(window.x().end());
-    const unsigned int gemm_3d_height = _input->info()->tensor_shape().y() / _gemm_3d_depth;
+    const unsigned int gemm_3d_height = _input->info()->tensor_shape().y() / _output_3d_depth;
 
     Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
     win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
@@ -177,7 +166,7 @@
         {
             // Calculate output coordinates
             Coordinates out_coords = id;
-            if(_gemm_3d_depth != 1)
+            if(_output_3d_depth != 1)
             {
                 out_coords.set(Window::DimY, id.y() % gemm_3d_height);
                 out_coords.set(Window::DimZ, id.y() / gemm_3d_height);
@@ -240,10 +229,10 @@
         {
             // Calculate output coordinates
             Coordinates out_coords = id;
-            if(_gemm_3d_depth != 1)
+            if(_output_3d_depth != 1)
             {
-                out_coords.set(Window::DimY, id.y() % _gemm_3d_depth);
-                out_coords.set(Window::DimZ, id.y() / _gemm_3d_depth);
+                out_coords.set(Window::DimY, id.y() % gemm_3d_height);
+                out_coords.set(Window::DimZ, id.y() / gemm_3d_height);
                 out_coords.set(3, id.z());
             }
             uint8_t *out_ptr = _output->ptr_to_element(out_coords);
@@ -279,22 +268,22 @@
 }
 
 NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel()
-    : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr), _result_fixedpoint_multiplier(0), _result_shift(0), _result_offset_after_shift(0), _min(0), _max(0), _gemm_3d_depth(1)
+    : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr), _result_fixedpoint_multiplier(0), _result_shift(0), _result_offset_after_shift(0), _min(0), _max(0), _output_3d_depth(1)
 {
 }
 
 void NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_fixedpoint_multiplier, int result_shift,
-                                                                          int result_offset_after_shift, int min, int max, unsigned int gemm_3d_depth)
+                                                                          int result_offset_after_shift, int min, int max, unsigned int output_3d_depth)
 {
     // Perform validate step
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
 
     // Output auto initialization if not yet initialized
-    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_output_stage_shape(*input->info(), gemm_3d_depth);
+    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_output_stage_shape(*input->info(), output_3d_depth);
     auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(DataType::QASYMM8).set_tensor_shape(output_shape));
 
     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias != nullptr) ? bias->info() : nullptr, output->info(),
-                                                  min, max, gemm_3d_depth));
+                                                  min, max, output_3d_depth));
 
     _input                        = input;
     _bias                         = bias;
@@ -304,7 +293,7 @@
     _result_offset_after_shift    = result_offset_after_shift;
     _min                          = min;
     _max                          = max;
-    _gemm_3d_depth                = gemm_3d_depth;
+    _output_3d_depth              = output_3d_depth;
 
     // Configure kernel window
     auto win_config = validate_and_configure_window(input->info(), (bias != nullptr) ? bias->info() : nullptr, output->info());
@@ -316,10 +305,10 @@
     _func                      = is_bounded_relu ? &NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::run<true> : &NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::run<false>;
 }
 
-Status NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max, unsigned int gemm_3d_depth)
+Status NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max, unsigned int output_3d_depth)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, min, max, gemm_3d_depth));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, min, max, output_3d_depth));
     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
                                                               (bias != nullptr) ? bias->clone().get() : nullptr,
                                                               output->clone().get())
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index bd5e969..f41a12a 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -92,8 +92,8 @@
 
 CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _col2im_kernel(), _activationlayer_function(),
-      _add_bias_kernel(), _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false),
-      _skip_im2col(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
+      _add_bias_kernel(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false),
+      _skip_col2im(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
 {
 }
 
@@ -102,6 +102,9 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
     ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info(), gemm_3d_depth, _skip_im2col));
 
+    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
+                                         gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
+
     if(_is_quantized)
     {
         // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -112,7 +115,7 @@
         input->info()->set_quantization_info(QuantizationInfo(input_quantization_info.scale, -input_quantization_info.offset));
         weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.scale, -weights_quantization_info.offset));
 
-        _mm_gemmlowp.configure(input, weights, output, GEMMInfo(false, false, true /* Reshape weights only for the first run*/));
+        _mm_gemmlowp.configure(input, weights, output, gemm_info);
 
         // Revert back QuantizationInfo as input and weights could be used in other convolution layers
         input->info()->set_quantization_info(input_quantization_info);
@@ -121,8 +124,7 @@
     else
     {
         // Configure matrix multiply function
-        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth,
-                                                                                 _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */));
+        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
     }
 }
 
@@ -130,10 +132,11 @@
 {
     const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
 
+    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
+                                         gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
+
     if(is_quantized)
     {
-        const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */);
-
         // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
         // Extract and negate input and weights offset
         const QuantizationInfo input_quantization_info   = input->quantization_info();
@@ -149,8 +152,6 @@
     }
     else
     {
-        const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
-
         // Perform validation step on Matrix multiply function
         return CLGEMM::validate(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
     }
@@ -175,6 +176,7 @@
     const DataLayout data_layout = input->info()->data_layout();
     const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
     const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+    const int        idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
     const int        idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
 
     const unsigned int kernel_width  = weights->info()->dimension(idx_width);
@@ -184,14 +186,14 @@
     _original_weights = weights;
     _is_quantized     = is_data_type_quantized_asymmetric(input->info()->data_type());
     _data_layout      = data_layout;
-    _skip_im2col      = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !_is_quantized;
+    _skip_im2col      = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+    _skip_col2im      = data_layout == DataLayout::NHWC;
     _append_bias      = (biases != nullptr) && (!_is_quantized);
 
     // Set the GPU target for im2col and col2im
     _im2col_kernel.set_target(CLScheduler::get().target());
     _col2im_kernel.set_target(CLScheduler::get().target());
 
-    bool             is_nhwc                   = _data_layout == DataLayout::NHWC;
     const ICLTensor *gemm_input_to_use         = input;
     ICLTensor       *gemm_output_to_use        = output;
     ICLTensor       *gemm_output_staged_to_use = output;
@@ -241,18 +243,28 @@
     }
 
     // Create GEMM output tensor
-    if(!is_nhwc || _is_quantized)
+    if(!_skip_col2im || _is_quantized)
     {
-        // Calculate GEMM output shape
-        TensorShape shape_gemm = _im2col_output.info()->tensor_shape();
-        shape_gemm.set(0, mat_weights_cols);
-        shape_gemm.set(1, conv_w * conv_h);
-
+        TensorShape shape_gemm;
+        if(_skip_col2im)
+        {
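+            // col2im is skipped: the GEMM result is laid out directly in the NHWC output shape [OFM, conv_w, conv_h, batches]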
+            shape_gemm = input->info()->tensor_shape();
+            shape_gemm.set(idx_width, conv_w);
+            shape_gemm.set(idx_height, conv_h);
+            shape_gemm.set(idx_channel, mat_weights_cols);
+        }
+        else
+        {
+            shape_gemm = _im2col_output.info()->tensor_shape();
+            shape_gemm.set(0, mat_weights_cols);
+            shape_gemm.set(1, conv_w * conv_h);
+        }
         // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
         const DataType gemm_data_type = _is_quantized ? DataType::S32 : data_type;
         // FIXME: input->clone() doesn't work with subtensors for grouped convolutions.
         TensorInfo info_gemm(shape_gemm, 1, gemm_data_type);
-        info_gemm.set_quantization_info(output->info()->quantization_info());
+        info_gemm.set_quantization_info(output->info()->quantization_info()).set_data_layout(input->info()->data_layout());
         _gemm_output.allocator()->init(info_gemm);
         _memory_group.manage(&_gemm_output);
 
@@ -277,30 +288,29 @@
         int   output_multiplier, output_shift;
         quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
 
-        _memory_group.manage(&_tmp_output);
-        gemm_output_staged_to_use = &_tmp_output;
+        if(!_skip_col2im)
+        {
+            _memory_group.manage(&_tmp_output);
+            gemm_output_staged_to_use = &_tmp_output;
+        }
 
         _gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset);
     }
 
-    if(!is_nhwc || _is_quantized)
+    if(!_skip_col2im)
     {
-        if(input->info()->data_layout() == DataLayout::NCHW)
-        {
-            // Configure and tune Col2Im
-            _col2im_kernel.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output, Size2D(conv_w, conv_h), num_groups);
-            CLScheduler::get().tune_kernel_static(_col2im_kernel);
-        }
-        else
-        {
-            // Configure reshape layer
-            _reshape_layer.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output);
-        }
+        // Configure and tune Col2Im
+        _col2im_kernel.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output, Size2D(conv_w, conv_h), num_groups);
+        CLScheduler::get().tune_kernel_static(_col2im_kernel);
     }
 
-    if(!is_nhwc || _is_quantized)
+    if(!_skip_col2im)
     {
         _tmp_output.allocator()->allocate();
+    }
+
+    if(!_skip_col2im || _is_quantized)
+    {
         _gemm_output.allocator()->allocate();
     }
 
@@ -346,10 +356,10 @@
     const ITensorInfo *gemm_output_staged_to_use = output;
     const ITensorInfo *weights_to_use            = weights;
 
-    const bool is_nhwc      = data_layout == DataLayout::NHWC;
     const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
-    const bool skip_im2col  = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !is_quantized;
     const bool append_bias  = (biases != nullptr) && (!is_quantized);
+    const bool skip_im2col  = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+    const bool skip_col2im  = data_layout == DataLayout::NHWC;
 
     ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(idx_channel) * num_groups) != input->dimension(idx_channel));
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -411,19 +421,30 @@
     }
 
     // Create GEMM output tensor
-    if(!is_nhwc || is_quantized)
+    if(!skip_col2im || is_quantized)
     {
-        TensorShape shape_gemm = gemm_input_to_use->tensor_shape();
-        shape_gemm.set(0, mat_weights_cols);
-        shape_gemm.set(1, conv_w * conv_h);
         const DataType gemm_data_type = is_quantized ? DataType::S32 : data_type;
+        TensorShape    shape_gemm;
+        if(skip_col2im)
+        {
+            shape_gemm = input->tensor_shape();
+            shape_gemm.set(idx_width, conv_w);
+            shape_gemm.set(idx_height, conv_h);
+            shape_gemm.set(idx_channel, mat_weights_cols);
+        }
+        else
+        {
+            shape_gemm = gemm_input_to_use->tensor_shape();
+            shape_gemm.set(0, mat_weights_cols);
+            shape_gemm.set(1, conv_w * conv_h);
+        }
         // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
         info_gemm = TensorInfo(shape_gemm, 1, gemm_data_type);
-        info_gemm.set_quantization_info(output->quantization_info());
+        info_gemm.set_quantization_info(output->quantization_info()).set_data_layout(input->data_layout());
         gemm_output_to_use = &info_gemm;
     }
 
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1, skip_im2col));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, skip_col2im ? conv_h : 1, skip_im2col));
 
     if(is_quantized)
     {
@@ -431,23 +452,22 @@
         int   output_multiplier, output_shift;
         quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
 
-        tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
-        tmp_info.set_quantization_info(output->quantization_info());
-        gemm_output_staged_to_use = &tmp_info;
+        if(!skip_col2im)
+        {
+            tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
+            tmp_info.set_quantization_info(output->quantization_info());
+            gemm_output_staged_to_use = &tmp_info;
+        }
 
         // Validate output stage for quantized case
-        CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, output->quantization_info().offset);
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use));
     }
 
     // Validate Col2Im
-    if(!is_nhwc || is_quantized)
+    if(!skip_col2im)
     {
-        if(input->data_layout() == DataLayout::NCHW)
-        {
-            ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? gemm_output_staged_to_use : gemm_output_to_use,
-                                                                 output,
-                                                                 Size2D(conv_w, conv_h), num_groups));
-        }
+        ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output,
+                                                             Size2D(conv_w, conv_h), num_groups));
     }
 
     //Validate Activation Layer
@@ -492,16 +512,9 @@
     }
 
     // Reshape output matrix
-    if(_data_layout == DataLayout::NCHW || _is_quantized)
+    if(!_skip_col2im)
     {
-        if(_data_layout == DataLayout::NCHW)
-        {
-            CLScheduler::get().enqueue(_col2im_kernel, false);
-        }
-        else
-        {
-            _reshape_layer.run();
-        }
+        CLScheduler::get().enqueue(_col2im_kernel, false);
     }
 
     //Run Activation Layer if enabled
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 1d6f343..62e7ee7 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -108,7 +108,7 @@
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
     bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
-    const int     m                         = a->info()->dimension(1);
+    const int     m                         = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
     const int     n                         = b->info()->dimension(0);
     const int     k                         = a->info()->dimension(0);
     const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
@@ -206,12 +206,12 @@
     int32_t a_offset = a->quantization_info().offset;
     int32_t b_offset = b->quantization_info().offset;
 
-    const int     m                         = a->dimension(1);
+    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
+    const int     m                         = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
     const int     n                         = b->dimension(0);
     const int     k                         = a->dimension(0);
     constexpr int mult_transpose1xW_width   = 1;
     constexpr int mult_interleave4x4_height = 1;
-    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
     const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
 
     bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
diff --git a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp
index 16d8678..b18d23f 100644
--- a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,8 +28,8 @@
 #include "arm_compute/core/CL/kernels/CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h"
 #include "support/ToolchainSupport.h"
 
-using namespace arm_compute;
-
+namespace arm_compute
+{
 void CLGEMMLowpQuantizeDownInt32ToUint8Scale::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output, int result_offset, int result_mult_int, int result_shift, int min, int max)
 {
     auto k = arm_compute::support::cpp14::make_unique<CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>();
@@ -42,15 +42,18 @@
     return CLGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, min, max);
 }
 
-void CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output, int result_fixedpoint_multiplier, int result_shift,
-                                                                    int result_offset_after_shift, int min, int max)
+void CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::configure(const ICLTensor *input, const ICLTensor *bias, ICLTensor *output,
+                                                                    int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift,
+                                                                    int min, int max, unsigned int output_3d_depth)
 {
     auto k = arm_compute::support::cpp14::make_unique<CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>();
-    k->configure(input, bias, output, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max);
+    k->configure(input, bias, output, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, output_3d_depth);
     _kernel = std::move(k);
 }
 
-Status CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max)
+Status CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
+                                                                     int min, int max, unsigned int output_3d_depth)
 {
-    return CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, min, max);
-}
\ No newline at end of file
+    return CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, min, max, output_3d_depth);
+}
+} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
index cb70049..d270a77 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
@@ -43,14 +43,14 @@
 }
 
 void NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_fixedpoint_multiplier, int result_shift,
-                                                                    int result_offset_after_shift, int min, int max, unsigned int gemm_3d_depth)
+                                                                    int result_offset_after_shift, int min, int max, unsigned int output_3d_depth)
 {
     auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>();
-    k->configure(input, bias, output, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, gemm_3d_depth);
+    k->configure(input, bias, output, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, output_3d_depth);
     _kernel = std::move(k);
 }
 
-Status NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max, unsigned int gemm_3d_depth)
+Status NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max, unsigned int output_3d_depth)
 {
-    return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, min, max, gemm_3d_depth);
+    return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, min, max, output_3d_depth);
 }
\ No newline at end of file