COMPMID-556: Rework CLActivationLayer

Refactoring: unify the fixed-point and asymmetric quantized paths in
CLActivationLayerKernel, extend auto_init_if_empty() to propagate
quantization info, and add the is_data_type_quantized() /
is_data_type_asymmetric() helpers.

Change-Id: I879353299b655ec3026cccdfcfca2ee98abf14ea
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/94191
Reviewed-by: Michel Iwaniec <michel.iwaniec@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h
index 6e4d987..edb05e9 100644
--- a/arm_compute/core/Helpers.h
+++ b/arm_compute/core/Helpers.h
@@ -466,10 +466,16 @@
  * @param[in]     num_channels         New number of channels.
  * @param[in]     data_type            New data type
  * @param[in]     fixed_point_position New fixed point position
+ * @param[in]     quantization_info    (Optional) New quantization info
  *
  * @return True if the tensor info has been initialized
  */
-bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, int fixed_point_position);
+bool auto_init_if_empty(ITensorInfo       &info,
+                        const TensorShape &shape,
+                        int                num_channels,
+                        DataType           data_type,
+                        int                fixed_point_position,
+                        QuantizationInfo   quantization_info = QuantizationInfo());
 
 /* Set the shape to the specified value if the current assignment is empty.
  *
@@ -509,6 +514,17 @@
  * @return True if the fixed point position has been changed.
  */
 bool set_fixed_point_position_if_zero(ITensorInfo &info, int fixed_point_position);
+
+/** Set the quantization info to the specified value if
+ * the current quantization info is empty and the data type is of asymmetric quantized type
+ *
+ * @param[in,out] info              Tensor info used to check and assign.
+ * @param[in]     quantization_info Quantization info
+ *
+ * @return True if the quantization info has been changed.
+ */
+bool set_quantization_info_if_empty(ITensorInfo &info, QuantizationInfo quantization_info);
+
 /** Helper function to calculate the Valid Region for Scale.
  *
  * @param[in] src_info         Input tensor info used to check.
@@ -520,6 +536,7 @@
 * @return The corresponding valid region
  */
 ValidRegion calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape, InterpolationPolicy policy, BorderSize border_size, bool border_undefined);
+
 /** Convert a linear index into n-dimensional coordinates.
  *
  * @param[in] shape Shape of the n-dimensional tensor.
@@ -528,6 +545,7 @@
  * @return n-dimensional coordinates.
  */
 inline Coordinates index2coords(const TensorShape &shape, int index);
+
 /** Convert n-dimensional coordinates into a linear index.
  *
  * @param[in] shape Shape of the n-dimensional tensor.
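
Note: a minimal sketch of how the extended auto_init_if_empty() overload is
meant to be called; the shape, data type and QuantizationInfo values below are
purely illustrative:

    #include "arm_compute/core/Helpers.h"
    #include "arm_compute/core/TensorInfo.h"

    using namespace arm_compute;

    void example_auto_init()
    {
        TensorInfo        dst_info; // empty: will be auto-initialized
        const TensorShape shape(16U, 16U);

        // Shape, channels, type, fixed point position and (new) quantization
        // info are propagated in a single call.
        auto_init_if_empty(dst_info, shape, 1, DataType::QASYMM8, 0,
                           QuantizationInfo(0.1f, 128));
    }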
diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl
index de6c85e..1a27684 100644
--- a/arm_compute/core/Helpers.inl
+++ b/arm_compute/core/Helpers.inl
@@ -197,7 +197,12 @@
     }
 }
 
-inline bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, int fixed_point_position)
+inline bool auto_init_if_empty(ITensorInfo       &info,
+                               const TensorShape &shape,
+                               int                num_channels,
+                               DataType           data_type,
+                               int                fixed_point_position,
+                               QuantizationInfo   quantization_info)
 {
     if(info.tensor_shape().total_size() == 0)
     {
@@ -205,6 +210,7 @@
         info.set_num_channels(num_channels);
         info.set_tensor_shape(shape);
         info.set_fixed_point_position(fixed_point_position);
+        info.set_quantization_info(quantization_info);
         return true;
     }
 
@@ -255,6 +261,17 @@
     return false;
 }
 
+inline bool set_quantization_info_if_empty(ITensorInfo &info, QuantizationInfo quantization_info)
+{
+    if(info.quantization_info().empty() && (is_data_type_asymmetric(info.data_type())))
+    {
+        info.set_quantization_info(quantization_info);
+        return true;
+    }
+
+    return false;
+}
+
 inline ValidRegion calculate_valid_region_scale(const ITensorInfo &src_info, const TensorShape &dst_shape, InterpolationPolicy policy, BorderSize border_size, bool border_undefined)
 {
     const auto  wr = static_cast<float>(dst_shape[0]) / static_cast<float>(src_info.tensor_shape()[0]);
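
Note: a small usage sketch for set_quantization_info_if_empty(); the guard
only fires for an asymmetric quantized type whose quantization info is still
empty (values illustrative):

    #include "arm_compute/core/Helpers.h"
    #include "arm_compute/core/TensorInfo.h"

    using namespace arm_compute;

    void example_set_qinfo()
    {
        TensorInfo info(TensorShape(8U, 8U), 1, DataType::QASYMM8);

        // Empty quantization info + asymmetric type: assigned, returns true.
        bool changed = set_quantization_info_if_empty(info, QuantizationInfo(0.5f, 10));

        // Quantization info is no longer empty: no-op, returns false.
        changed = set_quantization_info_if_empty(info, QuantizationInfo(1.0f, 0));
        (void)changed;
    }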
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index 149e404..8e15a0a 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -708,6 +708,28 @@
     }
 }
 
+/** Check if a given data type is of quantized type
+ *
+ * @note Quantized is considered a superset of the fixed-point and asymmetric data types.
+ *
+ * @param[in] dt Input data type.
+ *
+ * @return True if data type is of quantized type, else false.
+ */
+inline bool is_data_type_quantized(DataType dt)
+{
+    switch(dt)
+    {
+        case DataType::QS8:
+        case DataType::QASYMM8:
+        case DataType::QS16:
+        case DataType::QS32:
+            return true;
+        default:
+            return false;
+    }
+}
+
 /** Check if a given data type is of fixed point type
  *
  * @param[in] dt Input data type.
@@ -727,6 +749,23 @@
     }
 }
 
+/** Check if a given data type is of asymmetric quantized type
+ *
+ * @param[in] dt Input data type.
+ *
+ * @return True if data type is of asymmetric quantized type, else false.
+ */
+inline bool is_data_type_asymmetric(DataType dt)
+{
+    switch(dt)
+    {
+        case DataType::QASYMM8:
+            return true;
+        default:
+            return false;
+    }
+}
+
 /** Create a string with the float in full precision.
  *
  * @param val Floating point value
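
Note: a quick illustration of how the two new predicates partition the data
types, assuming the cases listed in the hunks above:

    #include "arm_compute/core/Utils.h"
    #include <cassert>

    using namespace arm_compute;

    void example_predicates()
    {
        // Fixed-point types are quantized but not asymmetric...
        assert(is_data_type_quantized(DataType::QS8));
        assert(!is_data_type_asymmetric(DataType::QS8));

        // ...while QASYMM8 is both.
        assert(is_data_type_quantized(DataType::QASYMM8));
        assert(is_data_type_asymmetric(DataType::QASYMM8));

        // Floating point is neither.
        assert(!is_data_type_quantized(DataType::F32));
    }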
diff --git a/src/core/CL/kernels/CLActivationLayerKernel.cpp b/src/core/CL/kernels/CLActivationLayerKernel.cpp
index bed407a..42f577c 100644
--- a/src/core/CL/kernels/CLActivationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLActivationLayerKernel.cpp
@@ -51,18 +51,18 @@
 void CLActivationLayerKernel::configure(ICLTensor *input, ICLTensor *output, ActivationLayerInfo act_info)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32, DataType::QASYMM8);
-
-    // For QA8 only lower/upper bounded relu is supported
-    if(input->info()->data_type() == DataType::QASYMM8)
-    {
-        ARM_COMPUTE_ERROR_ON_MSG(act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
-                                 "For QASYMM8 only lower/upper bounded relu is supported");
-    }
+    ARM_COMPUTE_ERROR_ON_MSG((input->info()->data_type() == DataType::QASYMM8) && (act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU),
+                             "For QASYMM8 only lower/upper bounded relu is supported");
 
     if(output != nullptr)
     {
        // Output auto initialization if not yet initialized
-        auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+        auto_init_if_empty(*output->info(),
+                           input->info()->tensor_shape(),
+                           1,
+                           input->info()->data_type(),
+                           input->info()->fixed_point_position(),
+                           input->info()->quantization_info());
 
         ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
         ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
@@ -70,63 +70,70 @@
     }
 
     const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
+    DataType           dt                                = input->info()->data_type();
     const int          fixed_point_position              = input->info()->fixed_point_position();
     float              a_const                           = act_info.a();
     float              b_const                           = act_info.b();
-    if(is_data_type_fixed_point(input->info()->data_type()))
+    int                a_const_int                       = 0;
+    int                b_const_int                       = 0;
+
+    // Create quantized version of constants a, b if needed
+    if(is_data_type_quantized(dt))
     {
-        a_const = static_cast<int>(lround(a_const * (1 << fixed_point_position)));
-        b_const = static_cast<int>(lround(b_const * (1 << fixed_point_position)));
+        if(is_data_type_fixed_point(dt))
+        {
+            a_const_int = static_cast<int>(lround(a_const * (1 << fixed_point_position)));
+            b_const_int = static_cast<int>(lround(b_const * (1 << fixed_point_position)));
+        }
+        else
+        {
+            a_const_int = input->info()->quantization_info().quantize(a_const);
+            b_const_int = input->info()->quantization_info().quantize(b_const);
+        }
     }
 
     // Set build options
     std::set<std::string> build_opts;
     build_opts.emplace(("-DACT=" + lower_string(string_from_activation_func(act_info.activation()))));
-    build_opts.emplace(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())));
+    build_opts.emplace(("-DDATA_TYPE=" + get_cl_type_from_data_type(dt)));
     build_opts.emplace(("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration)));
 
-    if(input->info()->data_type() == DataType::QASYMM8)
+    if(is_data_type_quantized(dt))
     {
-        // For lower/upper bounded relu make sure that the min/max values are in the quantized input space
-        int a_const_u8 = input->info()->quantization_info().quantize(a_const);
-        int b_const_u8 = input->info()->quantization_info().quantize(b_const);
+        build_opts.emplace(("-DA_VAL=" + support::cpp11::to_string(a_const_int)));
+        build_opts.emplace(("-DB_VAL=" + support::cpp11::to_string(b_const_int)));
 
-        build_opts.emplace(("-DA_VAL=" + support::cpp11::to_string(a_const_u8)));
-        build_opts.emplace(("-DB_VAL=" + support::cpp11::to_string(b_const_u8)));
+        // Set scale and offset of the input and output
+        if(is_data_type_asymmetric(dt))
+        {
+            float s1 = input->info()->quantization_info().scale;
+            int   o1 = input->info()->quantization_info().offset;
+            // If output is nullptr, assume same quantization scale/offset as input
+            float s2 = output != nullptr ? output->info()->quantization_info().scale : s1;
+            int   o2 = output != nullptr ? output->info()->quantization_info().offset : o1;
+            build_opts.emplace(("-DS1_VAL=" + float_to_string_with_full_precision(s1)));
+            build_opts.emplace(("-DS2_VAL=" + float_to_string_with_full_precision(s2)));
+            build_opts.emplace(("-DO1_VAL=" + support::cpp11::to_string(o1)));
+            build_opts.emplace(("-DO2_VAL=" + support::cpp11::to_string(o2)));
+        }
     }
     else
     {
-        build_opts.emplace(("-DA_VAL=" + support::cpp11::to_string(a_const)));
-        build_opts.emplace(("-DB_VAL=" + support::cpp11::to_string(b_const)));
+        build_opts.emplace(("-DA_VAL=" + float_to_string_with_full_precision(a_const)));
+        build_opts.emplace(("-DB_VAL=" + float_to_string_with_full_precision(b_const)));
     }
 
     build_opts.emplace(output == nullptr ? "-DIN_PLACE" : "");
-    if(is_data_type_fixed_point(input->info()->data_type()))
+    if(is_data_type_fixed_point(dt))
     {
         build_opts.emplace(("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(fixed_point_position)));
     }
 
     // Create kernel
-    if(input->info()->data_type() == DataType::QASYMM8)
-    {
-        float s1 = input->info()->quantization_info().scale;
-        float o1 = input->info()->quantization_info().offset;
-        // If output is nullptr, assume same quantization scale/offset as input
-        float s2 = output != nullptr ? output->info()->quantization_info().scale : s1;
-        float o2 = output != nullptr ? output->info()->quantization_info().offset : o1;
-        build_opts.emplace(("-DS1_VAL=" + support::cpp11::to_string(s1)));
-        build_opts.emplace(("-DS2_VAL=" + support::cpp11::to_string(s2)));
-        build_opts.emplace(("-DO1_VAL=" + support::cpp11::to_string(o1)));
-        build_opts.emplace(("-DO2_VAL=" + support::cpp11::to_string(o2)));
-        _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("activation_layer_qa8", build_opts));
-    }
-    else
-    {
-        _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("activation_layer", build_opts));
-    }
+    std::string kernel_name = is_data_type_asymmetric(dt) ? std::string("activation_layer_qa8") : std::string("activation_layer");
+    _kernel                 = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts));
 
     // Make sure _kernel is initialized before calling the parent's configure
-
     _input  = input;
     _output = output;
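
Note: to make the constant-quantization path above concrete, a hedged sketch
of how a QASYMM8 bounded-relu limit ends up in a -DA_VAL build option; the
QuantizationInfo and the bound are made-up example values:

    #include "arm_compute/core/Types.h"
    #include "support/ToolchainSupport.h"
    #include <string>

    using namespace arm_compute;

    std::string example_a_val_build_opt()
    {
        // For LU_BOUNDED_RELU on QASYMM8 the float bound is mapped into the
        // quantized input space before being baked into the kernel source.
        const QuantizationInfo qinfo(0.05f, 10); // illustrative scale/offset
        const float            a_const = 6.f;    // e.g. a relu6 upper bound

        const int a_const_int = qinfo.quantize(a_const);
        return "-DA_VAL=" + support::cpp11::to_string(a_const_int);
    }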