COMPMID-1751: Remove output_3d_depth from NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint

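The 3D reinterpretation of the GEMM output now happens inside
NEGEMMLowpMatrixMultiplyCore, driven by GEMMInfo::depth_output_gemm3d()
and GEMMInfo::reinterpret_input_as_3d(), so the fixed-point output stage
no longer needs an output_3d_depth parameter.

A minimal before/after sketch of a quantized call site (tensor and
variable names are illustrative, not taken from this patch):

    // Before: the output stage performed the 3D reinterpretation itself
    // output_stage.configure(&gemm_out, &bias, &dst, mult, shift, offset, min, max, conv_h /* output_3d_depth */);

    // After: the 3D depth is requested through GEMMInfo and handled by the GEMM
    GEMMInfo info(false, false, true /* reshape B only on first run */,
                  conv_h /* depth_output_gemm3d */, skip_im2col /* reinterpret_input_as_3d */);
    gemmlowp.configure(&src, &weights, nullptr, &gemm_out, info);
    output_stage.configure(&gemm_out, &bias, &dst, mult, shift, offset, min, max);
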
Change-Id: I1d5bc4d24059917f9ddef0873dd3043b1f2320a8
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 82b9cb8..72a3e80 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -139,7 +139,7 @@
     ARM_COMPUTE_UNUSED(alpha);
 
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, output);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index d02c63c..2433201 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -101,6 +101,9 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
     ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info(), gemm_3d_depth, _skip_im2col));
 
+    const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
+                                        gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
+
     if(_is_quantized)
     {
         // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -111,7 +114,7 @@
         input->info()->set_quantization_info(QuantizationInfo(input_quantization_info.scale, -input_quantization_info.offset));
         weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.scale, -weights_quantization_info.offset));
 
-        _mm_gemmlowp.configure(input, weights, nullptr, output, GEMMInfo(false, false, true /* Reshape weights only for the first run*/));
+        _mm_gemmlowp.configure(input, weights, nullptr, output, gemm_info);
 
-        // Revert back QuantizatioInfo as input and weights could be used in other convolution layers
+        // Restore the original QuantizationInfo as input and weights could be used in other convolution layers
         input->info()->set_quantization_info(input_quantization_info);
@@ -120,8 +123,7 @@
     else
     {
         // Configure matrix multiply function
-        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth,
-                                                                                 _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */));
+        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
     }
 }
 
@@ -129,7 +131,8 @@
 {
     const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
 
-    const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col);
+    const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
+                                        gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
     if(is_quantized)
     {
         // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -256,15 +259,24 @@
     }
 
-    // Create temporary GEMM output tensor in case we cannot skip col2im
+    // Create temporary GEMM output tensor in case we cannot skip col2im, or for quantized input,
+    // where the raw S32 accumulators must be kept for the separate output stage
-    if(!_skip_col2im)
+    if(!_skip_col2im || _is_quantized)
     {
-        // Calculate GEMM output shape
-        TensorShape shape_gemm = _im2col_output.info()->tensor_shape();
-        shape_gemm.set(0, mat_weights_cols);
-        shape_gemm.set(1, conv_w * conv_h);
-
-        // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
+        // GEMM output should be S32, so the raw integer accumulators can be retrieved without quantized post-processing for quantized asymmetric input.
         const DataType gemm_data_type = _is_quantized ? DataType::S32 : data_type;
+        TensorShape    shape_gemm;
+
+        if(_is_quantized && _skip_col2im)
+        {
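+            // If col2im is skipped, the S32 GEMM result is produced directly in the final output shape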
+            shape_gemm = output->info()->tensor_shape();
+        }
+        else
+        {
+            // Calculate GEMM output shape
+            shape_gemm = _im2col_output.info()->tensor_shape();
+            shape_gemm.set(0, mat_weights_cols);
+            shape_gemm.set(1, conv_w * conv_h);
+        }
+
         // FIXME: input->clone() doesn't work with subtensors for grouped convolutions.
         TensorInfo info_gemm(shape_gemm, 1, gemm_data_type);
         info_gemm.set_quantization_info(output->info()->quantization_info()).set_data_layout(input->info()->data_layout());
@@ -321,8 +333,7 @@
             _is_activationlayer_enabled = false;
         }
 
-        _gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset, min_activation, max_activation,
-                                         skip_reshape ? conv_h : 1);
+        _gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset, min_activation, max_activation);
     }
 
     if(!_skip_col2im && _data_layout == DataLayout::NCHW)
@@ -336,7 +347,7 @@
         _tmp_output.allocator()->allocate();
     }
 
-    if(!_skip_col2im)
+    if(!_skip_col2im || _is_quantized)
     {
         _gemm_output.allocator()->allocate();
     }
@@ -464,18 +475,20 @@
     }
 
-    // Create temporary GEMM output tensor in case we cannot skip col2im
+    // Create the GEMM output tensor info. GEMM output should be S32, so the raw integer
+    // accumulators can be retrieved without quantized post-processing for quantized asymmetric input.
+    const DataType gemm_data_type = is_quantized ? DataType::S32 : data_type;
     if(!skip_col2im)
     {
         TensorShape shape_gemm = gemm_input_to_use->tensor_shape();
         shape_gemm.set(0, mat_weights_cols);
         shape_gemm.set(1, conv_w * conv_h);
-        const DataType gemm_data_type = is_quantized ? DataType::S32 : data_type;
-        // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
         info_gemm = TensorInfo(shape_gemm, 1, gemm_data_type);
-        info_gemm.set_quantization_info(output->quantization_info()).set_data_layout(input->data_layout());
-
-        gemm_output_to_use = &info_gemm;
     }
+    else
+    {
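+        // Col2im is skipped: validate the GEMM directly against the final output shape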
+        info_gemm = TensorInfo(output->tensor_shape(), 1, gemm_data_type);
+    }
+    info_gemm.set_quantization_info(output->quantization_info()).set_data_layout(input->data_layout());
+    gemm_output_to_use = &info_gemm;
 
     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, skip_col2im ? conv_h : 0, skip_im2col));
 
@@ -516,7 +529,7 @@
         }
 
         // Validate output stage for quantized case
-        NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, min_activation, max_activation, skip_reshape ? conv_h : 0);
+        NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, min_activation, max_activation);
     }
 
     // Validate Col2Im/ReshapeLayer
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index 16ee3d0..4b02694 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -190,42 +190,68 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(c != nullptr, "Bias addition not supported in NEGEMMLowpMatrixMultiplyCore");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG((a)->dimension(0) != (b)->dimension(1),
                                     "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((a)->dimension(1) != (output)->dimension(1),
-                                    "The output matrix must have the same number of rows as the matrix A");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((b)->dimension(0) != (output)->dimension(0),
-                                    "The output matrix must have the same number of columns as the matrix B");
-    ARM_COMPUTE_UNUSED(gemm_info);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "NEGEMMLowpMatrixMultiplyCore cannot reinterpret the input tensor as 3D");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 0, "NEGEMMLowpMatrixMultiplyCore cannot reinterpret the output tensor as 3D");
 
-    int32_t a_offset                         = a->quantization_info().offset;
-    int32_t b_offset                         = b->quantization_info().offset;
-    bool    run_vector_matrix_multiplication = a->dimension(1) < 2;
+    int32_t    a_offset                    = a->quantization_info().offset;
+    int32_t    b_offset                    = b->quantization_info().offset;
+    const bool reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
 
-    if(!run_vector_matrix_multiplication)
+    // Check if we need to run the optimized assembly kernel
+    const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, output, 1.f, 0.f, reshape_b_only_on_first_run));
+
+    if(run_optimised)
     {
-        // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
-        TensorShape shape_tmp_a = a->tensor_shape();
-        shape_tmp_a.set(0, a->dimension(0) * 4);
-        shape_tmp_a.set(1, std::ceil(a->dimension(1) / 4.f));
-
-        // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ]
-        TensorShape shape_tmp_b = b->tensor_shape();
-        shape_tmp_b.set(0, b->dimension(1) * 16);
-        shape_tmp_b.set(1, std::ceil(b->dimension(0) / 16.f));
-
-        TensorInfo info_a(shape_tmp_a, 1, a->data_type());
-        TensorInfo info_b(shape_tmp_b, 1, b->data_type());
-
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, &info_a));
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, &info_b));
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output));
+        if(output->total_size() != 0)
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != output->dimension(0));
+            if(gemm_info.depth_output_gemm3d() != 0)
+            {
+                if(gemm_info.reinterpret_input_as_3d())
+                {
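+                    // Input and output are both reinterpreted as 3D: heights and depths must match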
+                    ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
+                    ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(2) != output->dimension(2));
+                }
+                else
+                {
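+                    // Only the output is reinterpreted as 3D: its collapsed height (H * D) must match the input height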
+                    ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1) * output->dimension(2));
+                }
+            }
+            else
+            {
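+                // 2D case: output height must match the input height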
+                ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
+            }
+        }
     }
     else
     {
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(a, b, output));
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "NEGEMMLowpMatrixMultiplyCore cannot reinterpret the input tensor as 3D");
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 0, "NEGEMMLowpMatrixMultiplyCore cannot reinterpret the output tensor as 3D");
+
+        const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
+        if(!run_vector_matrix_multiplication)
+        {
+            // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
+            TensorShape shape_tmp_a = a->tensor_shape();
+            shape_tmp_a.set(0, a->dimension(0) * 4);
+            shape_tmp_a.set(1, std::ceil(a->dimension(1) / 4.f));
+
+            // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ]
+            TensorShape shape_tmp_b = b->tensor_shape();
+            shape_tmp_b.set(0, b->dimension(1) * 16);
+            shape_tmp_b.set(1, std::ceil(b->dimension(0) / 16.f));
+
+            TensorInfo info_a(shape_tmp_a, 1, a->data_type());
+            TensorInfo info_b(shape_tmp_b, 1, b->data_type());
+
+            ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, &info_a));
+            ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, &info_b));
+            ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output));
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyKernel::validate(a, b, output));
+        }
     }
 
     TensorInfo info_vector_sum_col, info_vector_sum_row;
diff --git a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
index d270a77..ce69fa0 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
@@ -43,14 +43,14 @@
 }
 
 void NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_fixedpoint_multiplier, int result_shift,
-                                                                    int result_offset_after_shift, int min, int max, unsigned int output_3d_depth)
+                                                                    int result_offset_after_shift, int min, int max)
 {
     auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>();
-    k->configure(input, bias, output, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, output_3d_depth);
+    k->configure(input, bias, output, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max);
     _kernel = std::move(k);
 }
 
-Status NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max, unsigned int output_3d_depth)
+Status NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max)
 {
-    return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, min, max, output_3d_depth);
+    return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, min, max);
 }
\ No newline at end of file