COMPMID-1740: Fuse batch normalization with Convolution Layer at graph level

Change-Id: I77ca51c2c72783cc26a099a6a9c3210cdbbe822d
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/797
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index 7242bc6..d0035d9 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -28,6 +28,7 @@
 #include "arm_compute/graph/Tensor.h"
 #include "arm_compute/graph/TypePrinter.h"
 #include "arm_compute/graph/Types.h"
+#include "arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h"
 #include "arm_compute/graph/backends/Utils.h"
 #include "arm_compute/graph/nodes/Nodes.h"
 
@@ -135,11 +136,12 @@
     validate_node<TargetInfo>(node, 5 /* expected inputs */, 1 /* expected outputs */);
 
     // Extract IO and info
-    typename TargetInfo::TensorType *input     = get_backing_tensor<TargetInfo>(node.input(0));
-    typename TargetInfo::TensorType *mean      = get_backing_tensor<TargetInfo>(node.input(1));
-    typename TargetInfo::TensorType *var       = get_backing_tensor<TargetInfo>(node.input(2));
-    typename TargetInfo::TensorType *beta      = get_backing_tensor<TargetInfo>(node.input(3));
-    typename TargetInfo::TensorType *gamma     = get_backing_tensor<TargetInfo>(node.input(4));
+    typename TargetInfo::TensorType *input = get_backing_tensor<TargetInfo>(node.input(0));
+    typename TargetInfo::TensorType *mean  = get_backing_tensor<TargetInfo>(node.input(1));
+    typename TargetInfo::TensorType *var   = get_backing_tensor<TargetInfo>(node.input(2));
+    typename TargetInfo::TensorType *beta  = get_backing_tensor<TargetInfo>(node.input(3));
+    typename TargetInfo::TensorType *gamma = get_backing_tensor<TargetInfo>(node.input(4));
+
     typename TargetInfo::TensorType *output    = get_backing_tensor<TargetInfo>(node.output(0));
     const float                      epsilon   = node.epsilon();
     const ActivationLayerInfo        fused_act = node.fused_activation();
@@ -163,6 +165,61 @@
     return std::move(func);
 }
 
+/** Create a backend fused convolution and batch normalization layer function
+ *
+ * @tparam FusedLayerTypes Fused layer types forwarded to FusedConvolutionBatchNormalizationFunction
+ * @tparam TargetInfo      Target-specific information
+ *
+ * @param[in] node Node to create the backend function for
+ *
+ * @return Backend fused convolution and batch normalization layer function
+ */
+template <typename FusedLayerTypes, typename TargetInfo>
+std::unique_ptr<IFunction> create_fused_convolution_batch_normalization_layer(FusedConvolutionBatchNormalizationNode &node)
+{
+    validate_node<TargetInfo>(node, 7 /* expected inputs */, 1 /* expected outputs */);
+
+    // Extract IO and info (node inputs in order: input, weights, biases, mean, var, beta, gamma)
+    typename TargetInfo::TensorType *input   = get_backing_tensor<TargetInfo>(node.input(0));
+    typename TargetInfo::TensorType *weights = get_backing_tensor<TargetInfo>(node.input(1));
+    typename TargetInfo::TensorType *biases  = get_backing_tensor<TargetInfo>(node.input(2));
+    typename TargetInfo::TensorType *mean    = get_backing_tensor<TargetInfo>(node.input(3));
+    typename TargetInfo::TensorType *var     = get_backing_tensor<TargetInfo>(node.input(4));
+    typename TargetInfo::TensorType *beta    = get_backing_tensor<TargetInfo>(node.input(5));
+    typename TargetInfo::TensorType *gamma   = get_backing_tensor<TargetInfo>(node.input(6));
+    typename TargetInfo::TensorType *output  = get_backing_tensor<TargetInfo>(node.output(0));
+
+    const PadStrideInfo       conv_info  = node.convolution_info();
+    const unsigned int        num_groups = node.num_groups();
+    const bool                fast_math  = node.fast_math_hint() == FastMathHint::Enabled;
+    const ActivationLayerInfo fused_act  = node.fused_activation();
+    const float               epsilon    = node.epsilon();
+
+    // When the input is asymmetric-quantized and biases are present, force the biases data type to S32
+    const bool is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
+    if(is_quantized && biases != nullptr)
+    {
+        biases->info()->set_data_type(DataType::S32);
+    }
+
+    // Create and configure function
+    auto func = support::cpp14::make_unique<FusedConvolutionBatchNormalizationFunction<TargetInfo, FusedLayerTypes>>();
+    func->configure(input, weights, biases, output, mean, var, beta, gamma, epsilon, conv_info, num_groups, fast_math, fused_act);
+
+    // Log info (node name first, then the node type)
+    ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated "
+                               << node.name()
+                               << " Type: " << node.type()
+                               << " Target: " << TargetInfo::TargetType
+                               << " Data Type: " << input->info()->data_type()
+                               << " Input shape: " << input->info()->tensor_shape()
+                               << " Weights shape: " << weights->info()->tensor_shape()
+                               << " Output shape: " << output->info()->tensor_shape()
+                               << (fused_act.enabled() ? " " + to_string(fused_act.activation()) : "")
+                               << std::endl);
+    return std::move(func);
+}
+
 /** Create a backend bounding box transform layer function
  *
  * @tparam BoundingBoxTransformLayerFunction    Backend bounding box transform function