COMPMID-2111: ConcatenateLayer API should accept an index instead of an enum

Alters the concatenate layer to be layout-agnostic and to accept an index
as the concatenation axis instead of a typed, layout-dependent
enumeration.
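
Below is a minimal sketch, not part of this patch, of why a plain index is
layout-agnostic where the enumeration was not. The DataLayout and
DataLayoutDimension stand-ins and the dimension ordering are simplified
assumptions mirroring the get_dimension_idx(DataLayout, DataLayoutDimension)
overload used in the diff below:

    #include <cstddef>

    // Simplified stand-ins for the library's types; hypothetical,
    // for illustration only.
    enum class DataLayout { NCHW, NHWC };
    enum class DataLayoutDimension { CHANNEL, HEIGHT, WIDTH, BATCHES };

    // Resolves a layout-dependent dimension to a concrete shape index,
    // assuming shapes are stored innermost-first:
    // NCHW -> [W, H, C, N], NHWC -> [C, W, H, N].
    inline size_t get_dimension_idx(DataLayout layout, DataLayoutDimension dim)
    {
        const bool is_nchw = (layout == DataLayout::NCHW);
        switch(dim)
        {
            case DataLayoutDimension::WIDTH:   return is_nchw ? 0 : 1;
            case DataLayoutDimension::HEIGHT:  return is_nchw ? 1 : 2;
            case DataLayoutDimension::CHANNEL: return is_nchw ? 2 : 0;
            default:                           return 3; // BATCHES
        }
    }

    // Illustrative call site: resolve the enum to an index once, then
    // pass the plain size_t axis to the layout-agnostic concatenate API.
    // const size_t axis = get_dimension_idx(layout, DataLayoutDimension::CHANNEL);
    // concat.configure(inputs, &output, axis);

With a helper like this, the caller resolves the layout-dependent enum to a
concrete shape index once, and everything downstream of it works on plain
indices regardless of whether the tensor is NCHW or NHWC.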

Change-Id: I0eaaf919f66a1ba1b09bbfb47c171fc1d4045530
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/994
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index b96a242..9f8dd69 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -221,14 +221,15 @@
 
     // Get input tensor descriptor
     const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
+    const DataLayout       input_data_layout = input_tensor_desc.layout;
 
     // Create weights node
     TensorDescriptor w_desc = input_tensor_desc;
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL),
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL),
                      get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL) / num_groups);
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::BATCHES), depth);
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::BATCHES), depth);
     if(!weights_quant_info.empty())
     {
         w_desc.quant_info = weights_quant_info;
@@ -275,14 +276,15 @@
 
     // Get input tensor descriptor
     const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
+    const DataLayout       input_data_layout = input_tensor_desc.layout;
 
     // Create weights node
     TensorDescriptor w_desc = input_tensor_desc;
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL),
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL),
                      get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL));
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::BATCHES), depth);
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::BATCHES), depth);
 
     NodeID w_nid = add_const_node_with_name(g, params, "Weights", w_desc, std::move(weights_accessor));
 
@@ -328,12 +330,13 @@
 
     // Get input tensor descriptor
     const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
+    const DataLayout       input_data_layout = input_tensor_desc.layout;
 
     // Create weights node
     TensorDescriptor w_desc = input_tensor_desc;
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
-    w_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL),
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::WIDTH), kernel_spatial_extend.width);
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::HEIGHT), kernel_spatial_extend.height);
+    w_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL),
                      get_dimension_size(input_tensor_desc, DataLayoutDimension::CHANNEL) * depth_multiplier);
     if(!quant_info.empty())
     {
@@ -595,13 +598,14 @@
 
     // Get input tensor descriptor
     const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
+    const DataLayout       input_data_layout = input_tensor_desc.layout;
 
     // Create mul node
     TensorDescriptor mul_desc = input_tensor_desc;
-    const size_t     C        = input_tensor_desc.shape[get_dimension_idx(mul_desc, DataLayoutDimension::CHANNEL)];
-    mul_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::WIDTH), 1);
-    mul_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::HEIGHT), 1);
-    mul_desc.shape.set(get_dimension_idx(input_tensor_desc, DataLayoutDimension::CHANNEL), C);
+    const size_t     C        = input_tensor_desc.shape[get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL)];
+    mul_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::WIDTH), 1);
+    mul_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::HEIGHT), 1);
+    mul_desc.shape.set(get_dimension_idx(input_data_layout, DataLayoutDimension::CHANNEL), C);
     NodeID      mul_const_nid   = add_const_node_with_name(g, params, "Mul", mul_desc, std::move(mul_accessor));
     NodeIdxPair mul_const_nidxp = { mul_const_nid, 0 };