COMPMID-1505: Add native grouping support at graph level

Change-Id: Iedc91b0aee743b59af5140c8acb8124548da3163
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144362
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
diff --git a/src/graph/mutators/SplitLayerSubTensorMutator.cpp b/src/graph/mutators/SplitLayerSubTensorMutator.cpp
index 2a8c029..5f1c9c3 100644
--- a/src/graph/mutators/SplitLayerSubTensorMutator.cpp
+++ b/src/graph/mutators/SplitLayerSubTensorMutator.cpp
@@ -25,6 +25,7 @@
 
 #include "arm_compute/graph/Graph.h"
 #include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/algorithms/TopologicalSort.h"
 #include "arm_compute/graph/backends/BackendRegistry.h"
 #include "arm_compute/graph/nodes/SplitLayerNode.h"
 
@@ -42,10 +43,20 @@
 
 void SplitLayerSubTensorMutator::mutate(Graph &g)
 {
-    // Should be in reverse order of execution
-    for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes()))
+    // Early exit if no Split layers exist in graph
+    if(g.nodes(NodeType::SplitLayer).empty())
     {
-        if(node && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
+        return;
+    }
+
+    // Perform topological sort
+    std::vector<NodeID> topological_sorted_node_ids = dfs(g);
+
+    // Should be in reverse order of execution
+    for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids))
+    {
+        INode *node = g.node(node_id);
+        if(node != nullptr && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
         {
             // Get output tensor
             Tensor *input_tensor = node->input(0);
@@ -63,7 +74,7 @@
                 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
                                               << node->id() << " and name : " << node->name() << std::endl);
 
-                auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node.get());
+                auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node);
 
                 const unsigned int axis          = split_node->axis();
                 const unsigned int num_splits    = split_node->num_splits();