COMPMID-3097 Fuse activation with fully connected layer on CL

Change-Id: I447030e69b9e565f2f81529a41af8c5e7ece7ecf
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2702
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
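
Note (not part of the patch): with this change the activation can be passed through FullyConnectedLayerInfo instead of running a separate CLActivationLayer afterwards. A minimal usage sketch, assuming the CL runtime is already initialized and the tensors are configured and allocated elsewhere; the function name and the ReLU6 bound are illustrative only:

#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

using namespace arm_compute;

void fc_with_fused_relu6(ICLTensor &src, ICLTensor &weights, ICLTensor &bias, ICLTensor &dst)
{
    FullyConnectedLayerInfo fc_info{};
    // Clamp the FC output to [0, 6] inside the fused kernel (BOUNDED_RELU with a = 6)
    fc_info.activation_info = ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f);

    CLFullyConnectedLayer fc{};
    fc.configure(&src, &weights, &bias, &dst, fc_info);
    fc.run();
}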
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index dcaa126..9b7de8d 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -41,7 +41,7 @@
 namespace
 {
 Status construct_gemmlowp_output_stage(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output,
-                                       GEMMLowpOutputStageInfo &gemmlowp_output_stage)
+                                       GEMMLowpOutputStageInfo &gemmlowp_output_stage, ActivationLayerInfo activation_info)
 {
     gemmlowp_output_stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
     gemmlowp_output_stage.gemmlowp_offset     = 0;
@@ -53,13 +53,14 @@
     // Configure output stage for quantized case
     if(is_data_type_quantized_asymmetric(data_type))
     {
-        const UniformQuantizationInfo iq_info = input.quantization_info().uniform();
-        const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();
-        const UniformQuantizationInfo oq_info = output.quantization_info().uniform();
+        const QuantizationInfo        oq_info = output.quantization_info();
+        const UniformQuantizationInfo iq_unif = input.quantization_info().uniform();
+        const UniformQuantizationInfo wq_unif = weights.quantization_info().uniform();
+        const UniformQuantizationInfo oq_unif = oq_info.uniform();
 
-        const auto output_quant_info = (output.total_size() == 0) ? iq_info : oq_info;
+        const auto output_quant_info = (output.total_size() == 0) ? iq_unif : oq_unif;
 
-        const float multiplier        = (iq_info.scale * wq_info.scale) / output_quant_info.scale;
+        const float multiplier        = (iq_unif.scale * wq_unif.scale) / output_quant_info.scale;
         int         output_multiplier = 0;
         int         output_shift      = 0;
         ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
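
The float multiplier computed above cannot be applied on the integer path directly; quantization::calculate_quantized_multiplier decomposes it into a 32-bit fixed-point multiplier and a shift. A self-contained sketch of that decomposition, in the spirit of the library helper rather than a copy of it, assuming multiplier > 0:

#include <cmath>
#include <cstdint>

// Decompose 'multiplier' so that multiplier ~= quant_multiplier * 2^-(31 + shift),
// with quant_multiplier a Q0.31 fixed-point value. Illustrative only.
void decompose_multiplier(float multiplier, std::int32_t &quant_multiplier, std::int32_t &shift)
{
    int exponent = 0;
    const float q = std::frexp(multiplier, &exponent); // multiplier = q * 2^exponent, q in [0.5, 1)
    shift         = -exponent;
    auto q_fixed  = static_cast<std::int64_t>(std::llround(q * (1ll << 31)));
    if(q_fixed == (1ll << 31)) // rounding reached 1.0: renormalize
    {
        q_fixed /= 2;
        --shift;
    }
    quant_multiplier = static_cast<std::int32_t>(q_fixed);
}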
@@ -68,6 +69,27 @@
         PixelValue type_max{};
         std::tie(type_min, type_max) = get_min_max(data_type);
 
+        if(activation_info.enabled())
+        {
+            switch(activation_info.activation())
+            {
+                case ActivationLayerInfo::ActivationFunction::RELU:
+                    type_min = PixelValue(oq_unif.offset);
+                    break;
+                case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
+                    type_min = PixelValue(oq_unif.offset);
+                    type_max = PixelValue(activation_info.a(), data_type, oq_info);
+                    break;
+                case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
+                    type_min = PixelValue(activation_info.b(), data_type, oq_info);
+                    type_max = PixelValue(activation_info.a(), data_type, oq_info);
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Activation function not supported.");
+                    break;
+            }
+        }
+
         // Set the GEMMLowp output stage info
         gemmlowp_output_stage.gemmlowp_offset     = output_quant_info.offset;
         gemmlowp_output_stage.gemmlowp_multiplier = output_multiplier;
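
The clamping added above is what actually fuses the activation: for the ReLU variants the lower bound becomes the output zero point (the quantized value of 0.0), and the float bounds a()/b() are quantized against the output's QuantizationInfo, mirroring what PixelValue(..., data_type, oq_info) computes. A sketch of that arithmetic for QASYMM8, with assumed scale/offset values:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize a real-valued activation bound into the QASYMM8 output domain.
std::uint8_t quantize_bound(float real_bound, float scale, std::int32_t offset)
{
    const auto q = static_cast<std::int32_t>(std::lround(real_bound / scale)) + offset;
    return static_cast<std::uint8_t>(std::min<std::int32_t>(255, std::max<std::int32_t>(0, q)));
}

// Example (assumed output quantization: scale = 0.05f, offset = 10), BOUNDED_RELU with a = 6.f:
//   min clamp = offset                         = 10  (real 0.0)
//   max clamp = quantize_bound(6.f, 0.05f, 10) = 130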
@@ -84,7 +106,7 @@
 Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo *bias, const ITensorInfo &output, const FullyConnectedLayerInfo &fc_info)
 {
     GEMMLowpOutputStageInfo gemmlowp_output_stage;
-    ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(input, weights, output, gemmlowp_output_stage));
+    ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(input, weights, output, gemmlowp_output_stage, fc_info.activation_info));
 
     const GEMMInfo &gemm_info = GEMMInfo(false,                           // is_a_reshaped
                                          false,                           // is_b_reshaped
@@ -144,7 +166,7 @@
 void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info)
 {
     GEMMLowpOutputStageInfo gemmlowp_output_stage;
-    construct_gemmlowp_output_stage(*input->info(), *weights->info(), *output->info(), gemmlowp_output_stage);
+    construct_gemmlowp_output_stage(*input->info(), *weights->info(), *output->info(), gemmlowp_output_stage, fc_info.activation_info);
 
     const GEMMInfo &gemm_info = GEMMInfo(false,                           // is_a_reshaped
                                          false,                           // is_b_reshaped
@@ -155,7 +177,7 @@
                                          gemmlowp_output_stage,           // gemmlowp_output_stage
                                          fc_info.fp_mixed_precision,      // fp_mixed_precision
                                          true,                            // broadcast_bias
-                                         ActivationLayerInfo());          // activation_info
+                                         fc_info.activation_info);        // activation_info
 
     if(_is_quantized)
     {
@@ -313,6 +335,8 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
+    ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(input->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
+                                && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
 
     bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
     bool is_fc_after_conv = true;
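
Because only RELU, BOUNDED_RELU and LU_BOUNDED_RELU can be fused on the quantized path (see the new check above), callers can probe support up front through the static validate(). A sketch with hypothetical shapes and quantization parameters, for illustration only:

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

using namespace arm_compute;

bool can_fuse_relu6(unsigned int num_inputs, unsigned int num_outputs)
{
    // Hypothetical sizes and quantization parameters; weights laid out as
    // (num_inputs, num_outputs) with the default transpose_weights = true.
    const TensorInfo src(TensorShape(num_inputs), 1, DataType::QASYMM8, QuantizationInfo(0.05f, 10));
    const TensorInfo weights(TensorShape(num_inputs, num_outputs), 1, DataType::QASYMM8, QuantizationInfo(0.01f, 128));
    const TensorInfo bias(TensorShape(num_outputs), 1, DataType::S32);
    const TensorInfo dst(TensorShape(num_outputs), 1, DataType::QASYMM8, QuantizationInfo(0.1f, 5));

    FullyConnectedLayerInfo fc_info{};
    fc_info.activation_info = ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f);

    return static_cast<bool>(CLFullyConnectedLayer::validate(&src, &weights, &bias, &dst, fc_info));
}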