COMPMID-1043: Rework GCGEMMMatrixMultiplyKernel interface and allow auto initialization of the tensors
This patch also:
- removes support for already reshaped weights in GCConvolutionLayer
- makes GCConvolutionLayer similar to CLGEMMConvolutionLayer
- enables usage of the GCGEMM function in GCConvolutionLayer instead of calling the
GEMM kernels directly
Change-Id: I3e4a64335555e86e18585d38d8fda4bfdb44e265
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127696
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp
index 4c08873..55bf9b7 100644
--- a/src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp
@@ -31,11 +31,13 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/GLES_COMPUTE/GCHelpers.h"
using namespace arm_compute;
using namespace arm_compute::gles_compute;
+using namespace arm_compute::misc::shape_calculator;
GCWeightsReshapeKernel::GCWeightsReshapeKernel()
: _input(nullptr), _biases(nullptr), _output(nullptr)
@@ -47,15 +49,8 @@
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
- // Calculate output shape
- TensorShape output_shape{ input->info()->tensor_shape() };
- output_shape.collapse(3);
- const size_t tmp_dim = output_shape[0];
- output_shape.set(0, output_shape[1]);
- output_shape.set(1, tmp_dim + (biases != nullptr ? 1 : 0));
-
// Output tensor auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
+ auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_weights_reshaped_shape(*input->info(), (biases != nullptr))));
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);