COMPMID-1485 - Add support for NHWC when running NEGEMMConvolutionLayer with FP16/QASYMM8
When the GEMM3D check fails, we now fall back to the classic implementation with im2col
and col2im. This way the function can also work with QASYMM8 and FP16.
Change-Id: I359e9da3a63956f33b5acbc9bca4383b14af10e2
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/143372
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
index a362a29..e587cb4 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
@@ -37,6 +37,7 @@
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h"
+#include "arm_compute/runtime/NEON/functions/NEReshapeLayer.h"
#include "arm_compute/runtime/Tensor.h"
#include <memory>
@@ -84,7 +85,7 @@
* -# @ref NEGEMMLowpMatrixMultiplyCore (if the data type is QASYMM8)
* -# @ref NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint (if the data type is QASYMM8)
* -# @ref NEArithmeticAdditionKernel (if biases != nullptr and we have a 1x1 convolution with the NHWC data layout)
- * -# @ref NECol2ImKernel
+ * -# @ref NECol2ImKernel or @ref NEReshapeLayer (if NHWC and GEMM3D is not supported)
*
*/
class NEGEMMConvolutionLayer : public IFunction
@@ -165,6 +166,15 @@
* @return a status
*/
static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth = 1, bool skip_im2col = false);
+ /** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref NEGEMMLowpMatrixMultiplyCore
+ *
+ * @param[in] data_type Input data type
+ * @param[in] gemm_3d_depth Depth of GEMM 3D
+ * @param[in] skip_im2col Flag which specifies if im2col has to be skipped, i.e. for a 1x1 convolution with NHWC data layout
+ *
+ * @return a status
+ */
+ static Status validate_gemm3d(DataType data_type, int gemm_3d_depth, bool skip_im2col);
private:
MemoryGroup _memory_group;
@@ -176,6 +186,7 @@
NECol2ImKernel _col2im_kernel;
NEActivationLayer _activationlayer_function;
NEArithmeticAdditionKernel _add_bias_kernel;
+ NEReshapeLayer _reshape_layer;
const ITensor *_original_weights;