COMPMID-1367: Enable NHWC in graph examples

Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
index c56f4c5..241c07b 100644
--- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp
+++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
@@ -25,8 +25,9 @@
 
 #include "arm_compute/graph/Graph.h"
 #include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/Utils.h"
 #include "arm_compute/graph/backends/BackendRegistry.h"
-#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h"
+#include "arm_compute/graph/nodes/ConcatenateLayerNode.h"
 
 #include "arm_compute/core/utils/misc/Cast.h"
 #include "arm_compute/core/utils/misc/Iterable.h"
@@ -45,11 +46,18 @@
     // Should be in reverse order of execution
     for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes()))
     {
-        if(node && node->type() == NodeType::DepthConcatenateLayer && node->output(0) != nullptr)
+        if(node && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr)
         {
             // Get output tensor
             auto output_tensor = node->output(0);
 
+            // Check concatenation axis (sub-tensor optimization is supported only for concatenation axis >= 2)
+            auto *concat_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node.get());
+            if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc(), concat_node->concatenation_axis()) < 2)
+            {
+                continue;
+            }
+
             // Check that all tensor have the same target and valid inputs
             bool is_valid = std::all_of(node->input_edges().cbegin(), node->input_edges().cend(),
                                         [&](const EdgeID & eid)
@@ -76,7 +84,7 @@
                     depth += input_shape.z();
                 }
 
-                auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<DepthConcatenateLayerNode *>(node.get());
+                auto *dc_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node.get());
                 dc_node->set_enabled(false);
             }
         }