Fix graph fusion with post ops

- Fuse ConvolutionBatchNormalization nodes with post ops (activation
  or element-wise ops); a short sketch of the resulting post-op list
  follows
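
  As an illustration, a minimal sketch of how such a post-op list is
  assembled, mirroring the translation loop added to FunctionHelpers.h
  (the residual operand and the exact header paths are assumptions,
  not part of this patch):

      #include "arm_compute/core/ITensorInfo.h"
      #include "arm_compute/core/Types.h"
      #include "arm_compute/core/experimental/PostOps.h"

      using namespace arm_compute;

      // Post ops for a Conv+BN node followed by an element-wise add
      // of a residual tensor and a ReLU activation.
      void build_post_ops(ITensorInfo *residual) // hypothetical operand
      {
          experimental::PostOpList<ITensorInfo *> post_ops;
          // 1 = position of the previous op's destination among the
          // add operands; saturate on overflow
          post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
              residual, 1, ConvertPolicy::SATURATE);
          post_ops.push_back_op<experimental::PostOpAct<ITensorInfo *>>(
              ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
      }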

Resolves: COMPMID-4982
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Change-Id: I5b2d32cad00f710fd744cb5aa2d59fd7e5c97e0a
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6766
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index 1e420a8..a7e52d4 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -32,6 +32,7 @@
 #include "arm_compute/graph/Types.h"
 #include "arm_compute/graph/Utils.h"
 #include "arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h"
+#include "arm_compute/graph/backends/FusedConvolutionBatchNormalizationWithPostOpsFunction.h"
 #include "arm_compute/graph/backends/FusedDepthwiseConvolutionBatchNormalizationFunction.h"
 #include "arm_compute/graph/backends/Utils.h"
 #include "arm_compute/graph/nodes/Nodes.h"
@@ -540,7 +541,7 @@
     return std::move(func);
 }
 
-/** Create a backend convolution layer function with post opreator
+/** Create a backend convolution layer function with post operator
  *
  * @tparam ConvolutionLayerFunctions Backend convolution functions
  * @tparam TargetInfo                Target-specific information
@@ -629,6 +630,91 @@
                                << " Output shape: " << output->info()->tensor_shape()
                                << qss.str()
                                << (fused_act.enabled() ? " " + to_string(fused_act.activation()) : "")
+                               << " Post ops" << post_ops;
+                               << std::endl);
+    return std::move(func);
+}
+
+/** Create a backend convolution batch normalization layer function with post operators
+ *
+ * @tparam FusedLayerTypes           Fused layer types
+ * @tparam TargetInfo                Target-specific information
+ *
+ * @param[in] node Node to create the backend function for
+ * @param[in] ctx  Graph context
+ *
+ * @return Backend fused convolution with batch normalization layer function
+ */
+template <typename FusedLayerTypes, typename TargetInfo>
+std::unique_ptr<IFunction> create_fused_convolution_batch_normalization_with_post_op(FusedConvolutionBatchNormalizationWithPostOpsNode &node, GraphContext &ctx)
+{
+    validate_node<TargetInfo>(node, 8 /* expected inputs */, 1 /* expected outputs */);
+
+    // Extract IO and info
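+    // Assumed layout of the 8 inputs: [0] input, [1] weights, [2] biases,
+    // [3] mean, [4] var, [5] beta, [6] gamma, [7] optional eltwise-add operand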
+    typename TargetInfo::TensorType *input   = get_backing_tensor<TargetInfo>(node.input(0));
+    typename TargetInfo::TensorType *weights = get_backing_tensor<TargetInfo>(node.input(1));
+    typename TargetInfo::TensorType *biases  = get_backing_tensor<TargetInfo>(node.input(2));
+    typename TargetInfo::TensorType *mean    = get_backing_tensor<TargetInfo>(node.input(3));
+    typename TargetInfo::TensorType *var     = get_backing_tensor<TargetInfo>(node.input(4));
+    typename TargetInfo::TensorType *beta    = get_backing_tensor<TargetInfo>(node.input(5));
+    typename TargetInfo::TensorType *gamma   = get_backing_tensor<TargetInfo>(node.input(6));
+
+    typename TargetInfo::TensorType *output = get_backing_tensor<TargetInfo>(node.output(0));
+
+    const PadStrideInfo conv_info  = node.convolution_info();
+    const unsigned int  num_groups = node.num_groups();
+    const bool          fast_math  = node.fast_math_hint() == FastMathHint::Enabled;
+    const float         epsilon    = node.epsilon();
+
+    experimental::PostOpList<typename TargetInfo::TensorType *> post_ops;
+
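+    // Translate the node's ConvPostOpInfo descriptors into concrete backend post ops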
+    auto &post_op_info_list = node.post_op_info_list();
+    for(const auto &post_op_info : post_op_info_list)
+    {
+        switch(post_op_info->type())
+        {
+            case PostOpType::Activation:
+            {
+                const auto act_info = utils::cast::polymorphic_downcast<const ConvPostOpInfoActivation *>(post_op_info.get());
+                post_ops.template push_back_op<experimental::PostOpAct<typename TargetInfo::TensorType *>>(act_info->_act);
+                break;
+            }
+            case PostOpType::Eltwise_Add:
+            {
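+                // The add operand is the node's trailing input; inputs 0-6 are
+                // consumed by the convolution/batch-norm tensors extracted above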
+                typename TargetInfo::TensorType *add_input    = get_backing_tensor<TargetInfo>(node.input(7));
+                const auto                       eltwise_info = utils::cast::polymorphic_downcast<const ConvPostOpInfoEltwiseAdd *>(post_op_info.get());
+                post_ops.template push_back_op<experimental::PostOpEltwiseAdd<typename TargetInfo::TensorType *>>(add_input, eltwise_info->_prev_op_dst_pos, eltwise_info->_policy);
+                break;
+            }
+            default:
+            {
+                ARM_COMPUTE_ERROR("Unsupported PostOpType");
+            }
+        }
+    }
+
+    // Create and configure function (we assume that functions have been validated before creation)
+    std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, TargetInfo::TargetType);
+    std::unique_ptr<IFunction>      func;
+    std::string                     func_name;
+
+    using FType = FusedConvolutionBatchNormalizationWithPostOpsFunction<TargetInfo, FusedLayerTypes>;
+
+    std::tie(func, func_name) = create_named_memory_managed_function<FType>(
+                                    std::string("FusedConvolutionBatchNormalizationLayerWithPostOpsLayer"), mm, input, weights, biases, output, mean, var, beta, gamma, epsilon, conv_info, num_groups, fast_math, post_ops);
+
+    // Log info
+    ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated "
+                               << node.name()
+                               << " Type: " << node.type()
+                               << " Target: " << TargetInfo::TargetType
+                               << " Data Type: " << input->info()->data_type()
+                               << " Input shape: " << input->info()->tensor_shape()
+                               << " Weights shape: " << weights->info()->tensor_shape()
+                               << " Output shape: " << output->info()->tensor_shape()
+                               << " Post Ops: " << post_ops
                                << std::endl);
     return std::move(func);
 }