COMPMID-3101 Fuse activation with floating point elementwise operation layers in CL

Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Change-Id: I1693f8664ba7c0dc8c076bbe7365cef1e667bd25
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2718
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/graph/LayerDescriptors.h b/arm_compute/graph/LayerDescriptors.h
index 0cf2031..d8e6a6a 100644
--- a/arm_compute/graph/LayerDescriptors.h
+++ b/arm_compute/graph/LayerDescriptors.h
@@ -70,20 +70,23 @@
 {
     /** Constructor
      *
-     * @param[in] op             Element-wise operation to perform
-     * @param[in] out_quant_info (Optional) Output quantization information. Defaults to empty @ref QuantizationInfo
-     * @param[in] c_policy       (Optional) Convert policy used for the operation. Defaults to @ref ConvertPolicy::SATURATE
-     * @param[in] r_policy       (Optional) Rounding policy used for the operation. Defaults to @ref RoundingPolicy::TO_ZERO
+     * @param[in] op               Element-wise operation to perform
+     * @param[in] out_quant_info   (Optional) Output quantization information. Defaults to empty @ref QuantizationInfo
+     * @param[in] c_policy         (Optional) Convert policy used for the operation. Defaults to @ref ConvertPolicy::SATURATE
+     * @param[in] r_policy         (Optional) Rounding policy used for the operation. Defaults to @ref RoundingPolicy::TO_ZERO
+     * @param[in] fused_activation (Optional) Fused activation information. Defaults to empty (identity) @ref ActivationLayerInfo
      */
-    EltwiseLayerDescriptor(EltwiseOperation op, QuantizationInfo out_quant_info = QuantizationInfo(), ConvertPolicy c_policy = ConvertPolicy::SATURATE, RoundingPolicy r_policy = RoundingPolicy::TO_ZERO)
-        : op(op), out_quant_info(out_quant_info), c_policy(c_policy), r_policy(r_policy)
+    EltwiseLayerDescriptor(EltwiseOperation op, QuantizationInfo out_quant_info = QuantizationInfo(), ConvertPolicy c_policy = ConvertPolicy::SATURATE, RoundingPolicy r_policy = RoundingPolicy::TO_ZERO,
+                           ActivationLayerInfo fused_activation = ActivationLayerInfo())
+        : op(op), out_quant_info(out_quant_info), c_policy(c_policy), r_policy(r_policy), fused_activation(fused_activation)
     {
     }
 
-    EltwiseOperation op;             /**< Element-wise operation to perform */
-    QuantizationInfo out_quant_info; /**< Output quantization information */
-    ConvertPolicy    c_policy;       /**< Convert policy */
-    RoundingPolicy   r_policy;       /**< Rounding policy */
+    EltwiseOperation    op;               /**< Element-wise operation to perform */
+    QuantizationInfo    out_quant_info;   /**< Output quantization information */
+    ConvertPolicy       c_policy;         /**< Convert policy */
+    RoundingPolicy      r_policy;         /**< Rounding policy */
+    ActivationLayerInfo fused_activation; /**< Fused activation info */
 };
 
 /** Deconvolution layer descriptor */
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index 44b24b5..382b18a 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -773,6 +773,7 @@
     typename TargetInfo::TensorType *output         = get_backing_tensor<TargetInfo>(node.output(0));
     const EltwiseOperation           eltwise_op     = node.eltwise_operation();
     const ConvertPolicy              convert_policy = node.convert_policy();
+    const ActivationLayerInfo        act_info       = node.fused_activation();
     ARM_COMPUTE_ERROR_ON(input1 == nullptr);
     ARM_COMPUTE_ERROR_ON(input2 == nullptr);
     ARM_COMPUTE_ERROR_ON(output == nullptr);
@@ -783,19 +784,19 @@
     {
         std::tie(func, func_name) = create_named_function<typename EltwiseFunctions::Addition>(
                                         std::string("ArithmeticAddition"),
-                                        input1, input2, output, convert_policy);
+                                        input1, input2, output, convert_policy, act_info);
     }
     else if(eltwise_op == EltwiseOperation::Sub)
     {
         std::tie(func, func_name) = create_named_function<typename EltwiseFunctions::Subtraction>(
                                         std::string("ArithmeticSubtraction"),
-                                        input1, input2, output, convert_policy);
+                                        input1, input2, output, convert_policy, act_info);
     }
     else if(eltwise_op == EltwiseOperation::Mul)
     {
         std::tie(func, func_name) = create_named_function<typename EltwiseFunctions::Multiplication>(
                                         std::string("PixelWiseMultiplication"),
-                                        input1, input2, output, 1.f, convert_policy, node.rounding_policy());
+                                        input1, input2, output, 1.f, convert_policy, node.rounding_policy(), act_info);
     }
     else
     {
diff --git a/arm_compute/graph/nodes/EltwiseLayerNode.h b/arm_compute/graph/nodes/EltwiseLayerNode.h
index 21c220a..d619ad2 100644
--- a/arm_compute/graph/nodes/EltwiseLayerNode.h
+++ b/arm_compute/graph/nodes/EltwiseLayerNode.h
@@ -57,12 +57,26 @@
      */
     RoundingPolicy rounding_policy() const;
 
+    /** Returns fused activation
+     *
+     * @return Fused activation
+     */
+    ActivationLayerInfo fused_activation() const;
+
+    /** Sets fused activation
+     *
+     * @param[in] fused_activation Fused activation to set
+     */
+    void set_fused_activation(ActivationLayerInfo fused_activation);
+
     // Inherited overridden methods:
     NodeType         type() const override;
     bool             forward_descriptors() override;
     TensorDescriptor configure_output(size_t idx) const override;
     void accept(INodeVisitor &v) override;
 
+    static constexpr NodeType node_type = NodeType::EltwiseLayer;
+
 private:
     descriptors::EltwiseLayerDescriptor descriptor;
 };