IVGCVSW-5107 Allow Split to use subtensor on x and y
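
 * SplitterLayer::CreateTensors now takes the TensorHandleFactoryRegistry so
   that it can query the handle factory's PaddingRequired capabilities.
 * Sub-tensors are only used for a split along the two innermost (x/y)
   dimensions when no consumer of a splitter output requires padding.
 * Adds Neon tests for Concat and Splitter on x/y with and without padding,
   and a Compare<> helper to CommonTestUtils.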

Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: I2370d260b750f36842c23f08e8a00ccf976d0aed
diff --git a/src/armnn/layers/SplitterLayer.cpp b/src/armnn/layers/SplitterLayer.cpp
index 2d469b0..72f27f7 100644
--- a/src/armnn/layers/SplitterLayer.cpp
+++ b/src/armnn/layers/SplitterLayer.cpp
@@ -33,7 +33,7 @@
 }
 
 template<typename FactoryType>
-void SplitterLayer::CreateTensors(const FactoryType& factory)
+void SplitterLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, const FactoryType& factory)
 {
     //If sub tensors are supported than all the "splitter" need to do is to
     //set the outputs to be appropriate sub tensors of the input.
@@ -41,8 +41,9 @@
 
     if (useSubTensors)
     {
-        const OutputSlot* slot = GetInputSlots()[0].GetConnectedOutputSlot();
+        // Get the output handler of the previous layer
         const OutputHandler& outputHandler = GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
+        const OutputSlot* slot = GetInputSlots()[0].GetConnectedOutputSlot();
 
         const TensorInfo& parentInfo = outputHandler.GetTensorInfo();
 
@@ -50,6 +51,36 @@
 
         std::vector<std::unique_ptr<ITensorHandle>> subTensors;
 
+        // Check whether the split is along x or y (the two innermost dimensions)
+        auto numberOfDimensions = m_Param.GetNumDimensions();
+
+        // Compute the split axes here; including the aclCommon helper for this causes header issues
+        auto ComputeSplitAxis = [&](const armnn::SplitterDescriptor& desc, const TensorShape& input)
+        {
+            unsigned int numSplit = desc.GetNumViews();
+            unsigned int numDimensions = desc.GetNumDimensions();
+            std::set<unsigned int> splitAxis;
+
+            for (unsigned int i = 0; i < numSplit; ++i)
+            {
+                for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
+                {
+                    if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
+                    {
+                        splitAxis.insert(dimIdx);
+                    }
+                }
+            }
+            return splitAxis;
+        };
+
+        std::set<unsigned int> axis = ComputeSplitAxis(m_Param, parentInfo.GetShape());
+        std::set<unsigned int>::iterator axisIt = axis.begin();
+
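+        // NOTE: assumes at least one split axis exists, i.e. some view size
+        // differs from the input shape; only the first axis found is checked.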
+        bool isOnXorY = numberOfDimensions >= 3 &&
+                            ((*axisIt == numberOfDimensions - 1) ||
+                                (*axisIt == numberOfDimensions - 2));
+
         //Creates the outputs as subtensors of the input.
         for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
         {
@@ -57,11 +88,50 @@
 
             OutputSlot& outSlot = GetOutputSlot(i);
             ITensorHandleFactory::FactoryId factoryId = outSlot.GetTensorHandleFactoryId();
+
+            const unsigned int numOutputSlots = GetNumOutputSlots();
+
+            // A sub-tensor can only be used for a split along x or y (the two
+            // innermost dimensions) if none of the consuming layers require padding
+            bool canUseSubTensorOnXorY = true;
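+            // Capabilities can only be queried through an ITensorHandleFactory;
+            // with the legacy workload factory there is no registered factory to query.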
+            bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value;
+            if (isTensorHandleFactory)
+            {
+                for (unsigned int it = 0; it < numOutputSlots; ++it)
+                {
+                    InputSlot* inputSlot = GetOutputSlot(it).GetConnection(0);
+                    ITensorHandleFactory* handleFactory  = registry.GetFactory(factoryId);
+                    std::vector<Capability> capabilities =
+                        handleFactory->GetCapabilities(&(inputSlot->GetOwningLayer()),
+                                                       this,
+                                                       CapabilityClass::PaddingRequired);
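+                    // An empty capability list means no PaddingRequired capability
+                    // was reported for this connection, so a sub-tensor is safe to use.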
+                    if (isOnXorY)
+                    {
+                        canUseSubTensorOnXorY = false;
+                        if (capabilities.empty())
+                        {
+                            canUseSubTensorOnXorY = true;
+                        }
+                    }
+
+                    if (!canUseSubTensorOnXorY)
+                    {
+                        break;
+                    }
+                }
+            }
+
             auto CreateSubTensor = [&]()
             {
-                // Make sure quantization parameters are in the same space
-                if (parentInfo.IsTypeSpaceMatch(info) &&
-                    factoryId == slot->GetTensorHandleFactoryId())
+                // Make sure:
+                // 1) quantization parameters are in the same space
+                // 2) the same TensorHandleFactory is used for the input and the splitter output
+                // 3) the output does not feed a Constant or an Input layer
+                // 4) for a split along x or y (the two innermost dimensions), no consuming layer requires padding
+                if (parentInfo.IsTypeSpaceMatch(info) && //(1)
+                    factoryId == slot->GetTensorHandleFactoryId() && //(2)
+                    GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Constant && //(3)
+                    GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Input && //(3)
+                    canUseSubTensorOnXorY) //(4)
                 {
                     return factory.CreateSubTensorHandle(*inputData,
                                                          info.GetShape(),
@@ -87,7 +157,6 @@
                 m_OutputHandlers[i].SetData(std::move(subTensor));
                 ++i;
             }
-
         }
     }
 
@@ -110,13 +179,13 @@
 
     if (factoryId == ITensorHandleFactory::LegacyFactoryId)
     {
-        CreateTensors(workloadFactory);
+        CreateTensors(registry, workloadFactory);
     }
     else
     {
         ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
         ARMNN_ASSERT(handleFactory);
-        CreateTensors(*handleFactory);
+        CreateTensors(registry, *handleFactory);
     }
 }
 
diff --git a/src/armnn/layers/SplitterLayer.hpp b/src/armnn/layers/SplitterLayer.hpp
index bd20890..ae725b9 100644
--- a/src/armnn/layers/SplitterLayer.hpp
+++ b/src/armnn/layers/SplitterLayer.hpp
@@ -57,7 +57,7 @@
 
 private:
     template <typename FactoryType>
-    void CreateTensors(const FactoryType& factory);
+    void CreateTensors(const TensorHandleFactoryRegistry& registry, const FactoryType& factory);
 };
 
 } // namespace
diff --git a/src/backends/backendsCommon/test/CommonTestUtils.hpp b/src/backends/backendsCommon/test/CommonTestUtils.hpp
index e96edc8..8c4da62 100644
--- a/src/backends/backendsCommon/test/CommonTestUtils.hpp
+++ b/src/backends/backendsCommon/test/CommonTestUtils.hpp
@@ -8,9 +8,11 @@
 #include <Graph.hpp>
 #include <SubgraphView.hpp>
 #include <SubgraphViewSelector.hpp>
+#include <ResolveType.hpp>
+
+#include <cmath> // for std::fabs in Compare()
 
 #include <armnn/BackendRegistry.hpp>
 
+#include <armnn/Types.hpp>
 #include <backendsCommon/CpuTensorHandle.hpp>
 
 #include <test/TestUtils.hpp>
@@ -50,6 +52,23 @@
     return map.find(key) != map.end();
 }
 
+// Utility template for comparing tensor elements
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+bool Compare(T a, T b, float tolerance = 0.000001f)
+{
+    if (ArmnnType == armnn::DataType::Boolean)
+    {
+        // NOTE: Boolean is represented as uint8_t (with zero equals
+        // false and everything else equals true), therefore values
+        // need to be casted to bool before comparing them
+        return static_cast<bool>(a) == static_cast<bool>(b);
+    }
+
+    // NOTE: All other types can be cast to float and compared with
+    // a certain level of tolerance
+    return std::fabs(static_cast<float>(a) - static_cast<float>(b)) <= tolerance;
+}
+
 template <typename ConvolutionLayer>
 void SetWeightAndBias(ConvolutionLayer* layer, const armnn::TensorInfo& weightInfo, const armnn::TensorInfo& biasInfo)
 {
diff --git a/src/backends/neon/test/NeonTensorHandleTests.cpp b/src/backends/neon/test/NeonTensorHandleTests.cpp
index c6a562f..c881632 100644
--- a/src/backends/neon/test/NeonTensorHandleTests.cpp
+++ b/src/backends/neon/test/NeonTensorHandleTests.cpp
@@ -2,7 +2,6 @@
 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
-
 #include <Graph.hpp>
 #include <Network.hpp>
 
@@ -13,8 +12,10 @@
 
 #include <test/GraphUtils.hpp>
 #include <arm_compute/runtime/Allocator.h>
+#include <backendsCommon/test/CommonTestUtils.hpp>
 
 #include <boost/test/unit_test.hpp>
+#include <armnn/utility/Assert.hpp>
 
 BOOST_AUTO_TEST_SUITE(NeonTensorHandleTests)
 using namespace armnn;
@@ -86,7 +87,7 @@
     BOOST_TEST(capabilities[0].m_Value);
 }
 
-BOOST_AUTO_TEST_CASE(ConcatOnXorYSubTensorsNoPaddinRequiredTest)
+BOOST_AUTO_TEST_CASE(ConcatOnXorYSubTensorsNoPaddingRequiredTest)
 {
     armnn::INetworkPtr net(armnn::INetwork::Create());
 
@@ -156,7 +157,469 @@
                 }
             }
             // sub-tensors should be supported in this configuration
-            BOOST_CHECK(numberOfSubTensors > 0);
+            ARMNN_ASSERT(numberOfSubTensors > 0);
+        }
+    }
+}
+
+BOOST_AUTO_TEST_CASE(ConcatOnXorYPaddingRequiredTest)
+{
+    armnn::INetworkPtr net(armnn::INetwork::Create());
+
+    // Set up tensor infos
+    const armnn::TensorInfo inputInfo = armnn::TensorInfo({2, 3, 2, 2}, armnn::DataType::Float32);
+    const armnn::TensorInfo intermediateInfo = armnn::TensorInfo({2, 3, 2, 2}, armnn::DataType::Float32);
+    const armnn::TensorInfo outputInfo = armnn::TensorInfo({2, 3, 4, 2}, armnn::DataType::Float32);
+
+    armnn::Pooling2dDescriptor descriptor;
+    descriptor.m_PoolType = armnn::PoolingAlgorithm::Average;
+    descriptor.m_PoolWidth = descriptor.m_PoolHeight = 3;
+    descriptor.m_StrideX = descriptor.m_StrideY = 1;
+    descriptor.m_PadLeft = 1;
+    descriptor.m_PadRight = 1;
+    descriptor.m_PadTop = 1;
+    descriptor.m_PadBottom = 1;
+    descriptor.m_PaddingMethod = armnn::PaddingMethod::IgnoreValue;
+
+    // Create the network
+    armnn::IConnectableLayer* const input0Layer = net->AddInputLayer(0, "input_0");
+    input0Layer->GetOutputSlot(0).SetTensorInfo(inputInfo);
+    armnn::IConnectableLayer* pooling2dLayer0 = net->AddPooling2dLayer(descriptor, "pooling2d_0");
+    pooling2dLayer0->GetOutputSlot(0).SetTensorInfo(intermediateInfo);
+    input0Layer->GetOutputSlot(0).Connect(pooling2dLayer0->GetInputSlot(0));
+
+    armnn::IConnectableLayer* const input1Layer = net->AddInputLayer(1, "input_1");
+    input1Layer->GetOutputSlot(0).SetTensorInfo(inputInfo);
+    armnn::IConnectableLayer* pooling2dLayer1 = net->AddPooling2dLayer(descriptor, "pooling2d_1");
+    pooling2dLayer1->GetOutputSlot(0).SetTensorInfo(intermediateInfo);
+    input1Layer->GetOutputSlot(0).Connect(pooling2dLayer1->GetInputSlot(0));
+
+    std::array<armnn::TensorShape, 2> concatInputShapes = { intermediateInfo.GetShape(), intermediateInfo.GetShape() };
+    armnn::IConnectableLayer* const concatLayer = net->AddConcatLayer(armnn::CreateDescriptorForConcatenation(
+        concatInputShapes.begin(), concatInputShapes.end(), 2), "concatenation");
+    concatLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+    pooling2dLayer0->GetOutputSlot(0).Connect(concatLayer->GetInputSlot(0));
+    pooling2dLayer1->GetOutputSlot(0).Connect(concatLayer->GetInputSlot(1));
+
+    armnn::IConnectableLayer* const outputLayer = net->AddOutputLayer(0, "output");
+    concatLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+    armnn::IRuntime::CreationOptions options;
+    armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
+
+    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
+    armnn::IOptimizedNetworkPtr optimizedNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
+
+    const armnn::Graph& theGraph = static_cast<armnn::OptimizedNetwork*>(optimizedNet.get())->GetGraph();
+
+    // Load graph into runtime
+    armnn::NetworkId networkIdentifier;
+    runtime->LoadNetwork(networkIdentifier, std::move(optimizedNet));
+
+    // Now check how many sub-tensors the Concat layer is consuming
+    auto TraceSubTensorHandleAncestry = [](armnn::ITensorHandle* const subTensorHandle)
+    {
+        if (subTensorHandle && subTensorHandle->GetParent())
+        {
+            return true;
+        }
+        return false;
+    };
+
+    unsigned int numberOfSubTensors = 0;
+    for (auto&& layer : theGraph)
+    {
+        if (layer->GetType() == armnn::LayerType::Concat)
+        {
+            for (unsigned int i = 0; i < layer->GetNumInputSlots(); ++i)
+            {
+                const armnn::OutputSlot* slot = layer->GetInputSlot(i).GetConnectedOutputSlot();
+                if (TraceSubTensorHandleAncestry(slot->GetOutputHandler().GetData()))
+                {
+                    ++numberOfSubTensors;
+                }
+            }
+        }
+    }
+    // Sub-tensors should not be used in this configuration, as the preceding pooling layers require padding
+    ARMNN_ASSERT(numberOfSubTensors == 0);
+}
+
+BOOST_AUTO_TEST_CASE(SplitterOnXorYNoPaddingRequiredTest)
+{
+    using namespace armnn;
+
+    unsigned int splitAxis = 2;
+    unsigned int numSplit = 2;
+
+    const TensorShape& inputShape = { 2, 3, 4, 2 };
+    const armnn::TensorInfo intermediateInfo = armnn::TensorInfo({ 2, 3, 2, 2 }, armnn::DataType::Float32);
+    const std::vector<TensorShape> outputShapes{{ 2, 3, 2, 2 },
+                                                { 2, 3, 2, 2 }};
+    const float qScale = 1.0f;
+    const int32_t qOffset = 0;
+
+    // Creates structures for input & output.
+    std::vector<float> inputData{
+            1, 2,
+            3, 4,
+            5, 6,
+            7, 8,
+            9, 10,
+            11, 12,
+            13, 14,
+            15, 16,
+            17, 18,
+            19, 20,
+            21, 22,
+            23, 24,
+            25, 26,
+            27, 28,
+            29, 30,
+            31, 32,
+            33, 34,
+            35, 36,
+            37, 38,
+            39, 40,
+            41, 42,
+            43, 44,
+            45, 46,
+            47, 48
+    };
+
+    std::vector<float> expectedOutput0{
+            1, 2,
+            3, 4,
+            9, 10,
+            11, 12,
+            17, 18,
+            19, 20,
+            25, 26,
+            27, 28,
+            33, 34,
+            35, 36,
+            41, 42,
+            43, 44
+    };
+
+    std::vector<float> expectedOutput1{
+            5, 6,
+            7, 8,
+            13, 14,
+            15, 16,
+            21, 22,
+            23, 24,
+            29, 30,
+            31, 32,
+            37, 38,
+            39, 40,
+            45, 46,
+            47, 48
+    };
+
+    // Builds up the structure of the network.
+    INetworkPtr net(INetwork::Create());
+
+    TensorInfo inputTensorInfo(inputShape, armnn::DataType::Float32, qScale, qOffset);
+
+    armnn::ElementwiseUnaryDescriptor descriptor(armnn::UnaryOperation::Abs);
+
+    // Splitter
+    std::vector<unsigned int> splitterDimSizes(inputShape.GetNumDimensions());
+
+    // Add current input shape to splitterDimSizes
+    for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i)
+    {
+        splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
+    }
+
+    if (splitterDimSizes[splitAxis] % numSplit != 0)
+    {
+        throw ParseException("Number of splits must evenly divide the dimension");
+    }
+
+    splitterDimSizes[splitAxis] /= numSplit;
+
+    SplitterDescriptor splitDesc(numSplit, inputShape.GetNumDimensions());
+
+    for (unsigned int g = 0; g < numSplit; ++g)
+    {
+        // Set the size of the views.
+        for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
+        {
+            splitDesc.SetViewSize(g, dimIdx, splitterDimSizes[dimIdx]);
+        }
+        splitDesc.SetViewOriginCoord(g, splitAxis, splitterDimSizes[splitAxis] * g);
+    }
+    IConnectableLayer* input = net->AddInputLayer(0, "input");
+    IConnectableLayer* elementWiseUnary0 = net->AddElementwiseUnaryLayer(descriptor, "elementwiseunary_0");
+    IConnectableLayer* elementWiseUnary1 = net->AddElementwiseUnaryLayer(descriptor, "elementwiseunary_1");
+    IConnectableLayer* splitter = net->AddSplitterLayer(splitDesc, "splitter");
+
+    // Connections
+    Connect(input, splitter, inputTensorInfo, 0, 0);
+    Connect(splitter, elementWiseUnary0, intermediateInfo, 0, 0);
+    Connect(splitter, elementWiseUnary1, intermediateInfo, 1, 0);
+
+    std::vector<IConnectableLayer*> elementwiseUnaryLayers{elementWiseUnary0, elementWiseUnary1};
+
+    for (unsigned int i = 0; i < outputShapes.size(); ++i)
+    {
+        TensorInfo outputTensorInfo(outputShapes[i], armnn::DataType::Float32, qScale, qOffset);
+        IConnectableLayer* output = net->AddOutputLayer(boost::numeric_cast<LayerBindingId>(i));
+        Connect(elementwiseUnaryLayers[i], output, outputTensorInfo, 0, 0);
+    }
+
+    std::map<int, std::vector<float>> inputTensorData = {{ 0, inputData }};
+    std::map<int, std::vector<float>> expectedOutputData = {{ 0, expectedOutput0 }, { 1, expectedOutput1 }};
+
+    armnn::IRuntime::CreationOptions options;
+    armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
+
+    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
+    armnn::IOptimizedNetworkPtr optimizedNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
+
+    const armnn::Graph& theGraph = static_cast<armnn::OptimizedNetwork*>(optimizedNet.get())->GetGraph();
+
+    // Load graph into runtime
+    armnn::NetworkId networkIdentifier;
+    runtime->LoadNetwork(networkIdentifier, std::move(optimizedNet));
+
+    // Now check how many sub-tensors feed the ElementwiseUnary layers
+    auto TraceSubTensorHandleAncestry = [](armnn::ITensorHandle* const subTensorHandle)
+    {
+        if (subTensorHandle && subTensorHandle->GetParent())
+        {
+            return true;
+        }
+        return false;
+    };
+
+    for (auto&& layer : theGraph)
+    {
+        if (layer->GetType() == armnn::LayerType::ElementwiseUnary)
+        {
+            unsigned int numberOfSubTensors = 0;
+            for (unsigned int i = 0; i < layer->GetNumInputSlots(); ++i)
+            {
+                const armnn::OutputSlot* slot = layer->GetInputSlot(i).GetConnectedOutputSlot();
+                if (TraceSubTensorHandleAncestry(slot->GetOutputHandler().GetData()))
+                {
+                    ++numberOfSubTensors;
+                }
+            }
+            // Sub-tensors should be used in this configuration, as no consumer requires padding
+            ARMNN_ASSERT(numberOfSubTensors > 0);
+        }
+    }
+
+    InputTensors inputTensors;
+    inputTensors.reserve(inputTensorData.size());
+    for (auto&& it : inputTensorData)
+    {
+        inputTensors.push_back({it.first,
+                              ConstTensor(runtime->GetInputTensorInfo(networkIdentifier, it.first), it.second.data())});
+    }
+    OutputTensors outputTensors;
+    outputTensors.reserve(expectedOutputData.size());
+    std::map<int, std::vector<float>> outputStorage;
+    for (auto&& it : expectedOutputData)
+    {
+        std::vector<float> out(it.second.size());
+        outputStorage.emplace(it.first, out);
+        outputTensors.push_back({it.first,
+                                 Tensor(runtime->GetOutputTensorInfo(networkIdentifier, it.first),
+                                               outputStorage.at(it.first).data())});
+    }
+
+    // Does the inference.
+    runtime->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
+
+    // Checks the results.
+    float tolerance = 0.000001f;
+    for (auto&& it : expectedOutputData)
+    {
+        std::vector<float> out = outputStorage.at(it.first);
+        for (unsigned int i = 0; i < out.size(); ++i)
+        {
+            BOOST_CHECK_MESSAGE(Compare<armnn::DataType::Float32>(it.second[i], out[i], tolerance) == true,
+                    "Actual output: " << out[i] << ". Expected output: " << it.second[i]);
+        }
+    }
+}
+
+BOOST_AUTO_TEST_CASE(SplitterOnXorYPaddingRequiredTest)
+{
+    using namespace armnn;
+
+    unsigned int splitAxis = 2;
+    unsigned int numSplit = 2;
+
+    const TensorShape& inputShape = { 1, 1, 4, 4 };
+    const armnn::TensorInfo intermediateInfo = armnn::TensorInfo({ 1, 1, 2, 4 }, armnn::DataType::Float32);
+    const std::vector<TensorShape> outputShapes{{ 1, 1, 2, 4 },
+                                                { 1, 1, 2, 4 }};
+
+    const float qScale = 1.0f;
+    const int32_t qOffset = 0;
+
+    // Creates structures for input & output.
+    std::vector<float> inputData{
+        9.0f,   27.0f,  18.0f,  36.0f,
+        18.0f,   9.0f,  18.0f,   9.0f,
+        27.0f,  18.0f,   9.0f,  27.0f,
+        9.0f,   27.0f,   9.0f,  18.0f,
+    };
+
+    std::vector<float> expectedOutput0{
+         7.0f,  11.0f,  13.0f, 9.0f,
+         7.0f,  11.0f,  13.0f, 9.0f
+    };
+
+    std::vector<float> expectedOutput1{
+        9.0f,  11.0f,  12.0f, 7.0f,
+        9.0f,  11.0f,  12.0f, 7.0f
+    };
+
+    // Builds up the structure of the network.
+    INetworkPtr net(INetwork::Create());
+
+    TensorInfo inputTensorInfo(inputShape, armnn::DataType::Float32, qScale, qOffset);
+
+    // Pooling
+    armnn::Pooling2dDescriptor descriptor;
+    descriptor.m_PoolType = armnn::PoolingAlgorithm::Average;
+    descriptor.m_PoolWidth = descriptor.m_PoolHeight = 3;
+    descriptor.m_StrideX = descriptor.m_StrideY = 1;
+    descriptor.m_PadLeft = 1;
+    descriptor.m_PadRight = 1;
+    descriptor.m_PadTop = 1;
+    descriptor.m_PadBottom = 1;
+    descriptor.m_PaddingMethod = armnn::PaddingMethod::IgnoreValue;
+
+    // Splitter
+    std::vector<unsigned int> splitterDimSizes(inputShape.GetNumDimensions());
+
+    // Add current input shape to splitterDimSizes
+    for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i)
+    {
+        splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
+    }
+
+    if (splitterDimSizes[splitAxis] % numSplit != 0)
+    {
+        throw ParseException("Number of splits must evenly divide the dimension");
+    }
+
+    splitterDimSizes[splitAxis] /= numSplit;
+
+    SplitterDescriptor splitDesc(numSplit, inputShape.GetNumDimensions());
+
+    for (unsigned int g = 0; g < numSplit; ++g)
+    {
+        // Set the size of the views.
+        for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
+        {
+            splitDesc.SetViewSize(g, dimIdx, splitterDimSizes[dimIdx]);
+        }
+        splitDesc.SetViewOriginCoord(g, splitAxis, splitterDimSizes[splitAxis] * g);
+    }
+
+    IConnectableLayer* input = net->AddInputLayer(0, "input");
+    IConnectableLayer* pooling2d0 = net->AddPooling2dLayer(descriptor, "pooling2d_0");
+    IConnectableLayer* pooling2d1 = net->AddPooling2dLayer(descriptor, "pooling2d_1");
+    IConnectableLayer* splitter = net->AddSplitterLayer(splitDesc, "splitter");
+
+    // Connections
+    Connect(input, splitter, inputTensorInfo, 0, 0);
+    Connect(splitter, pooling2d0, intermediateInfo, 0, 0);
+    Connect(splitter, pooling2d1, intermediateInfo, 1, 0);
+
+    std::vector<IConnectableLayer*> pooling2dLayers{pooling2d0, pooling2d1};
+
+    for (unsigned int i = 0; i < outputShapes.size(); ++i)
+    {
+        TensorInfo outputTensorInfo(outputShapes[i], armnn::DataType::Float32, qScale, qOffset);
+        IConnectableLayer* output = net->AddOutputLayer(boost::numeric_cast<LayerBindingId>(i));
+        Connect(pooling2dLayers[i], output, outputTensorInfo, 0, 0);
+    }
+
+    std::map<int, std::vector<float>> inputTensorData = {{ 0, inputData }};
+    std::map<int, std::vector<float>> expectedOutputData = {{ 0, expectedOutput0 }, { 1, expectedOutput1 }};
+
+    armnn::IRuntime::CreationOptions options;
+    armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
+
+    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
+    armnn::IOptimizedNetworkPtr optimizedNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
+
+    const armnn::Graph& theGraph = static_cast<armnn::OptimizedNetwork*>(optimizedNet.get())->GetGraph();
+
+    // Load graph into runtime
+    armnn::NetworkId networkIdentifier;
+    runtime->LoadNetwork(networkIdentifier, std::move(optimizedNet));
+
+    // Now check how many sub-tensors feed the Pooling2d layers
+    auto TraceSubTensorHandleAncestry = [](armnn::ITensorHandle* const subTensorHandle)
+    {
+        if (subTensorHandle && subTensorHandle->GetParent())
+        {
+            return true;
+        }
+        return false;
+    };
+
+    for (auto&& layer : theGraph)
+    {
+        if (layer->GetType() == armnn::LayerType::Pooling2d)
+        {
+            unsigned int numberOfSubTensors = 0;
+            for (unsigned int i = 0; i < layer->GetNumInputSlots(); ++i)
+            {
+                const armnn::OutputSlot* slot = layer->GetInputSlot(i).GetConnectedOutputSlot();
+                if (TraceSubTensorHandleAncestry(slot->GetOutputHandler().GetData()))
+                {
+                    ++numberOfSubTensors;
+                }
+            }
+            // Sub-tensors should not be used in this configuration, as the pooling layers require padding
+            ARMNN_ASSERT(numberOfSubTensors == 0);
+        }
+    }
+
+    InputTensors inputTensors;
+    inputTensors.reserve(inputTensorData.size());
+    for (auto&& it : inputTensorData)
+    {
+        inputTensors.push_back({it.first,
+                              ConstTensor(runtime->GetInputTensorInfo(networkIdentifier, it.first), it.second.data())});
+    }
+    OutputTensors outputTensors;
+    outputTensors.reserve(expectedOutputData.size());
+    std::map<int, std::vector<float>> outputStorage;
+    for (auto&& it : expectedOutputData)
+    {
+        std::vector<float> out(it.second.size());
+        outputStorage.emplace(it.first, out);
+        outputTensors.push_back({it.first,
+                                 Tensor(runtime->GetOutputTensorInfo(networkIdentifier, it.first),
+                                               outputStorage.at(it.first).data())});
+    }
+
+    // Does the inference.
+    runtime->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
+
+    // Checks the results.
+    float tolerance = 0.000001f;
+    for (auto&& it : expectedOutputData)
+    {
+        std::vector<float> out = outputStorage.at(it.first);
+        for (unsigned int i = 0; i < out.size(); ++i)
+        {
+            BOOST_CHECK_MESSAGE(Compare<armnn::DataType::Float32>(it.second[i], out[i], tolerance) == true,
+                    "Actual output: " << out[i] << ". Expected output: " << it.second[i]);
         }
     }
 }