IVGCVSW-8602 Move ComputeSplitAxis() to backendsCommon/WorkloadUtils

* Use ComputeSplitAxis in SplitOperator in tosaCommon mappings
* Fix TosaRef split tests, that were missing outputInfos

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: Ib577eacdc6399242f37d25494e208aa56db6334c
diff --git a/src/armnn/layers/SplitterLayer.cpp b/src/armnn/layers/SplitterLayer.cpp
index 8a24e0d..b04614b 100644
--- a/src/armnn/layers/SplitterLayer.cpp
+++ b/src/armnn/layers/SplitterLayer.cpp
@@ -9,6 +9,7 @@
 #include <armnn/TypesUtils.hpp>
 #include <armnn/backends/WorkloadData.hpp>
 #include <armnn/backends/WorkloadFactory.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
 
 namespace armnn
 {
@@ -57,26 +58,6 @@
         // check if split is along the x or y (2 innermost dimensions)
         auto numberOfDimensions = m_Param.GetNumDimensions();
 
-        // Compute split axis within class as aclCommon function causes header issues when included
-        auto ComputeSplitAxis = [&](const armnn::SplitterDescriptor& desc, const TensorShape& input)
-        {
-            unsigned int numSplit = desc.GetNumViews();
-            unsigned int numDimensions = desc.GetNumDimensions();
-            std::set<unsigned int> splitAxis;
-
-            for (unsigned int i = 0; i < numSplit; ++i)
-            {
-                for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
-                {
-                    if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
-                    {
-                        splitAxis.insert(dimIdx);
-                    }
-                }
-            }
-            return splitAxis;
-        };
-
         std::set<unsigned int> axis = ComputeSplitAxis(m_Param, parentInfo.GetShape());
         std::set<unsigned int>::iterator axisIt = axis.begin();
 
diff --git a/src/backends/aclCommon/ArmComputeUtils.hpp b/src/backends/aclCommon/ArmComputeUtils.hpp
index d7025aa..fc77f81 100644
--- a/src/backends/aclCommon/ArmComputeUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeUtils.hpp
@@ -242,32 +242,6 @@
     return aclAxis;
 }
 
-inline std::set<unsigned int> ComputeSplitAxis(const armnn::SplitterDescriptor& desc, const TensorShape& input)
-{
-    unsigned int numSplit = desc.GetNumViews();
-    unsigned int numDimensions = desc.GetNumDimensions();
-    std::set<unsigned int> splitAxis;
-
-    if (desc.HasAxis())
-    {
-        splitAxis.insert(armnnUtils::GetUnsignedAxis(desc.GetNumDimensions(), desc.GetAxis()));
-    }
-    else
-    {
-        for (unsigned int i = 0; i < numSplit; ++i)
-        {
-            for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
-            {
-                if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
-                {
-                    splitAxis.insert(dimIdx);
-                }
-            }
-        }
-    }
-    return splitAxis;
-}
-
 /// Function to convert ArmNN axis (left to right) to ACL axis (right to left) ranging from [-rank, rank)
 inline int ComputeAclAxis(const int& armnnAxis, const armnn::TensorInfo& tensor)
 {
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp
index e36c4b2..d459820 100644
--- a/src/backends/backendsCommon/WorkloadUtils.cpp
+++ b/src/backends/backendsCommon/WorkloadUtils.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017-2023 Arm Ltd. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -8,6 +8,7 @@
 #include <armnn/Utils.hpp>
 #include <armnn/utility/NumericCast.hpp>
 #include <armnnUtils/DataLayoutIndexed.hpp>
+#include <armnnUtils/TensorUtils.hpp>
 
 #include <fmt/format.h>
 #include <numeric>
@@ -373,4 +374,29 @@
     return permutationVector;
 }
 
+std::set<unsigned int> ComputeSplitAxis(const armnn::SplitterDescriptor& desc, const TensorShape& input)
+{
+    unsigned int numSplit = desc.GetNumViews();
+    unsigned int numDimensions = desc.GetNumDimensions();
+    std::set<unsigned int> splitAxis;
+    if (desc.HasAxis())
+    {
+        splitAxis.insert(armnnUtils::GetUnsignedAxis(desc.GetNumDimensions(), desc.GetAxis()));
+    }
+    else
+    {
+        for (unsigned int i = 0; i < numSplit; ++i)
+        {
+            for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
+            {
+                if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
+                {
+                    splitAxis.insert(dimIdx);
+                }
+            }
+        }
+    }
+    return splitAxis;
+}
+
 } // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadUtils.hpp b/src/backends/backendsCommon/WorkloadUtils.hpp
index 6350c25..0462df6 100644
--- a/src/backends/backendsCommon/WorkloadUtils.hpp
+++ b/src/backends/backendsCommon/WorkloadUtils.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017, 2023 Arm Ltd. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -279,4 +279,11 @@
 /// \return - A permutation vector that permutes the 2 last dimensions
 armnn::PermutationVector GeneratePermutationVectorOnLastTwoDimensions(unsigned int rank);
 
+/// Calculates the axis values for split operation.
+///
+/// \param desc - Splitter Descriptor
+/// \param input - Input tensor shape
+/// \return - A set containing axis values of slitter operation
+    std::set<unsigned int> ComputeSplitAxis(const armnn::SplitterDescriptor& desc, const TensorShape& input);
+
 }  //namespace armnn
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index bfe4f6e..9f7d562 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -18,6 +18,7 @@
 #if defined(ARMCOMPUTECL_ENABLED)
 #include <aclCommon/ArmComputeUtils.hpp>
 #include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
 #include "workloads/ClAbsWorkload.hpp"
 #include "workloads/ClAdditionWorkload.hpp"
 #include "workloads/ClActivationWorkload.hpp"
diff --git a/src/backends/cl/workloads/ClSplitterWorkload.cpp b/src/backends/cl/workloads/ClSplitterWorkload.cpp
index ec904eb..074ce5d 100644
--- a/src/backends/cl/workloads/ClSplitterWorkload.cpp
+++ b/src/backends/cl/workloads/ClSplitterWorkload.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2019-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -11,6 +11,7 @@
 #include <aclCommon/ArmComputeUtils.hpp>
 #include <armnn/utility/PolymorphicDowncast.hpp>
 #include <armnn/backends/TensorHandle.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
 #include <cl/ClTensorHandle.hpp>
 
 
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index ee8f6f2..0298c7c 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -19,6 +19,7 @@
 #if defined(ARMCOMPUTENEON_ENABLED)
 #include <aclCommon/ArmComputeUtils.hpp>
 #include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
 #include "workloads/NeonAbsWorkload.hpp"
 #include "workloads/NeonAdditionWorkload.hpp"
 #include "workloads/NeonActivationWorkload.hpp"
diff --git a/src/backends/neon/workloads/NeonSplitterWorkload.cpp b/src/backends/neon/workloads/NeonSplitterWorkload.cpp
index c307822..bfde497 100644
--- a/src/backends/neon/workloads/NeonSplitterWorkload.cpp
+++ b/src/backends/neon/workloads/NeonSplitterWorkload.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2019-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -12,6 +12,7 @@
 #include <armnn/utility/PolymorphicDowncast.hpp>
 #include <armnn/backends/TensorHandle.hpp>
 #include <neon/NeonTensorHandle.hpp>
+#include <backendsCommon/WorkloadUtils.hpp>
 
 #include "NeonWorkloadUtils.hpp"
 
diff --git a/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp b/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp
index f8c60b1..b733866 100644
--- a/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp
+++ b/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2023-2024 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 // Copyright © 2020 The TensorFlow Authors. All Rights Reserved.
@@ -7,6 +7,7 @@
 //
 
 #include "SplitOperator.hpp"
+#include <backendsCommon/WorkloadUtils.hpp>
 
 // This function is paraphrased from:
 // tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc from function convertSplitOp
@@ -56,26 +57,19 @@
         }
     }
 
-    // Each slice op has a different beginning point.
-    // The size is the same for each slice op.
-    std::vector<int32_t> beginVals;
-    beginVals.reserve(inputs[0]->GetNumDimensions());
-    std::vector<int32_t> sizeVals;
-    sizeVals.reserve(inputs[0]->GetNumDimensions());
-    for (unsigned int j = 0; j < inputs[0]->GetNumDimensions(); ++j)
+    // Configure input and output tensors
+    std::set<unsigned int> splitAxis = ComputeSplitAxis(*splitDescriptor, inputs[0]->GetShape());
+    if (splitAxis.size() != 1)
     {
-        beginVals.emplace_back(0);
-        uint32_t dim = inputs[0]->GetShape()[j];
-        sizeVals.emplace_back(dim);
+        throw InvalidArgumentException("Cannot derive split axis from SplitterDescriptor");
     }
-
-    uint32_t axis = static_cast<uint32_t>(splitDescriptor->GetAxis());
-    sizeVals[axis] = sizeVals[axis] / static_cast<int32_t>(numSplit);
+    uint32_t axis = *splitAxis.begin();
 
     std::vector<TosaSerializationOperator*> ops;
-    for (unsigned int i=0; i < numSplit; ++i)
+    std::vector<int32_t> beginVals(inputs[0]->GetNumDimensions(), 0);
+    for (unsigned int i = 0; i < numSplit; ++i)
     {
-        beginVals[axis] = static_cast<int>(i) * sizeVals[axis];
+        std::vector<int32_t> sizeVals = GetTosaTensorShape(outputs[i]->GetShape());
         TosaSliceAttribute attribute(beginVals, sizeVals);
         auto* op = new TosaSerializationOperator(Op_SLICE,
                                                  Attribute_SliceAttribute,
@@ -84,6 +78,9 @@
                                                  {outputNames[i]});
 
         ops.push_back(op);
+
+        // Update the axis begin value for the next split operation, to be the correct size axis value.
+        beginVals[axis] += sizeVals[axis];
     }
 
     std::vector<TosaSerializationTensor*> tensors;
@@ -98,13 +95,13 @@
         tensors.push_back(new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
     }
 
-    std::vector<int32_t> outputShape = GetTosaTensorShape(outputs[0]->GetShape());
     DType outputDType = ArmNNToDType(outputs[0]->GetDataType());
-
-    for (unsigned int i=0; i < numSplit; ++i)
+    for (unsigned int i = 0; i < numSplit; ++i)
     {
+        std::vector<int32_t> outputShape = GetTosaTensorShape(outputs[i]->GetShape());
         tensors.push_back(new TosaSerializationTensor(outputNames[i], outputShape, outputDType, {}));
     }
+
     // operatorInputNames/operatorOutputNames ends up being the same as
     // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings
     return new TosaSerializationBasicBlock(blockName, // name
diff --git a/src/backends/tosaCommon/test/OneToManyMappingTests.cpp b/src/backends/tosaCommon/test/OneToManyMappingTests.cpp
index 5a34ac2..991ef15 100644
--- a/src/backends/tosaCommon/test/OneToManyMappingTests.cpp
+++ b/src/backends/tosaCommon/test/OneToManyMappingTests.cpp
@@ -139,8 +139,11 @@
     armnn::TensorInfo inputTensorInfo({1, 18, 4, 4}, DataType::Float32);
     armnn::TensorInfo outputTensorInfo({1, 6, 4, 4}, DataType::Float32);
 
-    TosaSerializationBasicBlock* basicBlock =
-            GetTosaMapping(nullptr, LayerType::Splitter, {&inputTensorInfo}, {&outputTensorInfo}, descriptor);
+    TosaSerializationBasicBlock* basicBlock = GetTosaMapping(nullptr,
+                                                             LayerType::Splitter,
+                                                             {&inputTensorInfo},
+                                                             {&outputTensorInfo, &outputTensorInfo, &outputTensorInfo},
+                                                             descriptor);
 
     VerifySplit(basicBlock,
                 inShape,
diff --git a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
index 759b37f..28d7753 100644
--- a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
+++ b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
@@ -523,7 +523,7 @@
     TosaRefLayerSupport supportChecker;
     std::string reasonIfNotSupported;
     auto supported = supportChecker.IsLayerSupported(LayerType::Splitter,
-                                                     {in, out},
+                                                     {in, out, out, out},
                                                      descriptor,
                                                      EmptyOptional(),
                                                      EmptyOptional(),
@@ -547,7 +547,7 @@
     TosaRefLayerSupport supportChecker;
     std::string reasonIfNotSupported;
     auto supported = supportChecker.IsLayerSupported(LayerType::Splitter,
-                                                     {in, out},
+                                                     {in, out, out, out},
                                                      descriptor,
                                                      EmptyOptional(),
                                                      EmptyOptional(),