COMPMID-3236: Add support for QSYMM16 inputs with S32 output in CLPixelWiseMultiplicationKernel

Change-Id: Ifc519f53f04fcb14ddb9c17f98cc687f34285c97
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3018
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/pixelwise_mul_float.cl b/src/core/CL/cl_kernels/pixelwise_mul_float.cl
index aad4bec..163cb23 100644
--- a/src/core/CL/cl_kernels/pixelwise_mul_float.cl
+++ b/src/core/CL/cl_kernels/pixelwise_mul_float.cl
@@ -30,7 +30,7 @@
 #endif /* SATURATE */
 #define CONVERT_OP_FLOAT(x, type, round) CONVERT_OP_FLOAT_STR(x, type, round)
 
-#if defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(DATA_TYPE_RES) && defined(DATA_TYPE_OUT)
+#if defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(ACC_DATA_TYPE) && defined(DATA_TYPE_OUT)
 
 #if defined(ACTIVATION_TYPE)
 #include "activation_float_helpers.h"
@@ -40,8 +40,8 @@
  *
  * @attention The inputs and output data types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT:
  * e.g. -DDATA_TYPE_IN1=uchar -DDATA_TYPE_IN2=ushort -DDATA_TYPE_OUT=short
- * @attention The data type of the intermediate result of the multiplication should passed as well using -DDATA_TYPE_RES.
- * e.g. If one of inputs is S16 -DDATA_TYPE_RES=int should be passed else -DDATA_TYPE_RES=short.
+ * @attention The data type of the intermediate result of the multiplication should be passed as well using -DACC_DATA_TYPE.
+ * e.g. If one of the inputs is S16, -DACC_DATA_TYPE=int should be passed, else -DACC_DATA_TYPE=short.
  * @attention -DDATA_TYPE_FLOAT must be passed if floating point inputs are provided.
  *
  * @param[in]  in1_ptr                           Pointer to the source image. Supported data types: U8, S16, F16, F32
@@ -82,18 +82,18 @@
     Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
 
     // Load data
-    VEC_DATA_TYPE(DATA_TYPE_RES, 16)
-    in1_data = CONVERT(vload16(0, (__global DATA_TYPE_IN1 *)in1.ptr), VEC_DATA_TYPE(DATA_TYPE_RES, 16));
-    VEC_DATA_TYPE(DATA_TYPE_RES, 16)
-    in2_data = CONVERT(vload16(0, (__global DATA_TYPE_IN2 *)in2.ptr), VEC_DATA_TYPE(DATA_TYPE_RES, 16));
+    VEC_DATA_TYPE(ACC_DATA_TYPE, 16)
+    in1_data = CONVERT(vload16(0, (__global DATA_TYPE_IN1 *)in1.ptr), VEC_DATA_TYPE(ACC_DATA_TYPE, 16));
+    VEC_DATA_TYPE(ACC_DATA_TYPE, 16)
+    in2_data = CONVERT(vload16(0, (__global DATA_TYPE_IN2 *)in2.ptr), VEC_DATA_TYPE(ACC_DATA_TYPE, 16));
 
     // Perform multiplication
 #ifdef DATA_TYPE_FLOAT
     VEC_DATA_TYPE(DATA_TYPE_OUT, 16)
-    res = CONVERT(in1_data * in2_data * (DATA_TYPE_RES)scale, VEC_DATA_TYPE(DATA_TYPE_OUT, 16));
+    res = CONVERT(in1_data * in2_data * (ACC_DATA_TYPE)scale, VEC_DATA_TYPE(DATA_TYPE_OUT, 16));
 #else  /* DATA_TYPE_FLOAT */
     VEC_DATA_TYPE(DATA_TYPE_OUT, 16)
-    res = CONVERT_OP_FLOAT(CONVERT_OP_FLOAT((convert_float16(in1_data * in2_data) * scale), VEC_DATA_TYPE(DATA_TYPE_RES, 16), ROUND), VEC_DATA_TYPE(DATA_TYPE_OUT, 16), ROUND);
+    res = CONVERT_OP_FLOAT(CONVERT_OP_FLOAT((convert_float16(in1_data * in2_data) * scale), VEC_DATA_TYPE(ACC_DATA_TYPE, 16), ROUND), VEC_DATA_TYPE(DATA_TYPE_OUT, 16), ROUND);
 #endif /* DATA_TYPE_FLOAT */
 
 #if defined(ACTIVATION_TYPE)
@@ -103,7 +103,7 @@
     vstore16(res, 0, (__global DATA_TYPE_OUT *)out.ptr);
 #endif // defined(ACTIVATION_TYPE)
 }
-#endif /* defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(DATA_TYPE_RES) && defined(DATA_TYPE_OUT) */
+#endif /* defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(ACC_DATA_TYPE) && defined(DATA_TYPE_OUT) */
 
 /** Performs a pixelwise multiplication of complex float values
  *
diff --git a/src/core/CL/cl_kernels/pixelwise_mul_int.cl b/src/core/CL/cl_kernels/pixelwise_mul_int.cl
index d277c6c..097df82 100644
--- a/src/core/CL/cl_kernels/pixelwise_mul_int.cl
+++ b/src/core/CL/cl_kernels/pixelwise_mul_int.cl
@@ -35,13 +35,13 @@
 #define CONVERT_RTE(x, type) (convert_##type##_rte((x)))
 #define CONVERT_DOWN(x, type) CONVERT_RTE(x, type)
 
-#if defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(DATA_TYPE_RES) && defined(DATA_TYPE_OUT)
+#if defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(ACC_DATA_TYPE) && defined(DATA_TYPE_OUT)
 /** Performs a pixelwise multiplication with integer scale of integer inputs.
  *
  * @attention The inputs and output data types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT:
  * e.g. -DDATA_TYPE_IN1=uchar -DDATA_TYPE_IN2=ushort -DDATA_TYPE_OUT=short
- * @attention The data_type of the intermediate result of the multiplication should passed as well using -DDATA_TYPE_RES.
- * e.g. If one of inputs is S16 -DDATA_TYPE_RES=int should be passed else -DDATA_TYPE_RES=short.
+ * @attention The data type of the intermediate result of the multiplication should be passed as well using -DACC_DATA_TYPE.
+ * e.g. If one of the inputs is S16, -DACC_DATA_TYPE=int should be passed, else -DACC_DATA_TYPE=short.
  *
  * @param[in]  in1_ptr                           Pointer to the source image. Supported data types: U8/S16
  * @param[in]  in1_stride_x                      Stride of the source image in X dimension (in bytes)
@@ -81,15 +81,15 @@
     Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
 
     // Load data
-    VEC_DATA_TYPE(DATA_TYPE_RES, 16)
-    in1_data = CONVERT(vload16(0, (__global DATA_TYPE_IN1 *)in1.ptr), VEC_DATA_TYPE(DATA_TYPE_RES, 16));
-    VEC_DATA_TYPE(DATA_TYPE_RES, 16)
-    in2_data = CONVERT(vload16(0, (__global DATA_TYPE_IN2 *)in2.ptr), VEC_DATA_TYPE(DATA_TYPE_RES, 16));
+    VEC_DATA_TYPE(ACC_DATA_TYPE, 16)
+    in1_data = CONVERT(vload16(0, (__global DATA_TYPE_IN1 *)in1.ptr), VEC_DATA_TYPE(ACC_DATA_TYPE, 16));
+    VEC_DATA_TYPE(ACC_DATA_TYPE, 16)
+    in2_data = CONVERT(vload16(0, (__global DATA_TYPE_IN2 *)in2.ptr), VEC_DATA_TYPE(ACC_DATA_TYPE, 16));
 
     // Perform multiplication and store result
     vstore16(MUL_OP(in1_data, in2_data, scale, DATA_TYPE_OUT, 16), 0, (__global DATA_TYPE_OUT *)out.ptr);
 }
-#endif /* defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(DATA_TYPE_RES) && defined(DATA_TYPE_OUT) */
+#endif /* defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(ACC_DATA_TYPE) && defined(DATA_TYPE_OUT) */
 
 #if defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE_OUT) && defined(VEC_SIZE)
 
diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
index ff5afa3..2df3ff4 100644
--- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
+++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
@@ -28,17 +28,9 @@
 #include "arm_compute/core/CL/CLValidate.h"
 #include "arm_compute/core/CL/ICLTensor.h"
 #include "arm_compute/core/CL/OpenCL.h"
-#include "arm_compute/core/Error.h"
-#include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Window.h"
 #include "support/StringSupport.h"
 
-#include <cmath>
-#include <cstdlib>
-#include <set>
-#include <string>
-
 namespace arm_compute
 {
 namespace
@@ -77,7 +69,7 @@
                                                              1,
                                                              DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                              DataType::S16, DataType::QSYMM16, DataType::F16,
-                                                             DataType::F32);
+                                                             DataType::S32, DataType::F32);
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
                                         "Output can only be U8 if both inputs are U8");
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QASYMM8 && (input1->data_type() != DataType::QASYMM8 || input2->data_type() != DataType::QASYMM8),
@@ -86,6 +78,8 @@
                                         "Output can only be QASYMM8_SIGNED if both inputs are QASYMM8_SIGNED");
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QSYMM16 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16),
                                         "Output can only be QSYMM16 if both inputs are QSYMM16");
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16),
+                                        "Output can only be S32 if both inputs are QSYMM16");
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
     }
 
@@ -177,22 +171,24 @@
         scale_int = std::abs(exponent - 1);
     }
 
-    std::string compute_type;
+    std::string acc_type;
     // Check if it has float inputs and output
     if(is_data_type_float(input1->info()->data_type()) || is_data_type_float(input2->info()->data_type()))
     {
-        scale_int    = -1;
-        compute_type = (input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32) ? "float" : "half";
+        scale_int = -1;
+        acc_type  = (input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32) ? "float" : "half";
     }
     else
     {
-        if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+        if(input1->info()->element_size() == 2 || input2->info()->element_size() == 2)
         {
-            compute_type = "int";
+            // Use 32-bit accumulator for 16-bit input
+            acc_type = "int";
         }
         else
         {
-            compute_type = "ushort";
+            // Use 16-bit accumulator for 8-bit input
+            acc_type = "ushort";
         }
     }
 
@@ -205,7 +201,7 @@
     build_opts.add_option("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(input2->info()->data_type()));
     build_opts.add_option("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(output->info()->data_type()));
     build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
-    if(is_quantized)
+    if(is_quantized && (output->info()->data_type() != DataType::S32))
     {
         const UniformQuantizationInfo iq1_info = input1->info()->quantization_info().uniform();
         const UniformQuantizationInfo iq2_info = input2->info()->quantization_info().uniform();
@@ -227,7 +223,7 @@
         kernel_name += (scale_int >= 0) ? "_int" : "_float";
         build_opts.add_option_if_else(overflow_policy == ConvertPolicy::WRAP || is_data_type_float(output->info()->data_type()), "-DWRAP", "-DSATURATE");
         build_opts.add_option_if_else(rounding_policy == RoundingPolicy::TO_ZERO, "-DROUND=_rtz", "-DROUND=_rte");
-        build_opts.add_option("-DDATA_TYPE_RES=" + compute_type);
+        build_opts.add_option("-DACC_DATA_TYPE=" + acc_type);
         if(act_info.enabled())
         {
             build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_info.activation())));