IVGCVSW-3536 Add Axis parameter to reference Softmax implementation

 * Add Axis parameter to SoftmaxDescriptor
 * Add new reference implementation for Softmax using the Axis parameter
 * Add unit tests covering every axis (including negative indices) for 2D, 3D and 4D inputs

Change-Id: Iafac2275d2212337456f2b1b56b0f76f77fb9543
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
diff --git a/src/backends/backendsCommon/test/LayerTests.cpp b/src/backends/backendsCommon/test/LayerTests.cpp
index d6e0e87..b40a3f5 100644
--- a/src/backends/backendsCommon/test/LayerTests.cpp
+++ b/src/backends/backendsCommon/test/LayerTests.cpp
@@ -77,6 +77,41 @@
 // 2-channel bias used by a number of Conv2d tests.
 static std::vector<float> Bias2({0, 2});
 
+// Default input and expected output shared by the Simple3dSoftmax*Test variants.
+// The shape is { 1, 8, 1 } and softmax runs over the 8-element dimension (axis 1),
+// so every expected value is exp(x) / sum(exp(x)) over all eight inputs (beta = 1).
+struct Simple3dSoftmaxOutputData
+{
+    const std::vector<float> outputData =
+            {
+                0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
+                0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f
+            };
+
+    const armnn::TensorShape inputShape{ 1, 8, 1 };
+
+    const std::vector<float> inputData =
+            {
+                    0.f, 1.f, 0.f, 0.f,
+                    .5f, 0.f, 0.f, 0.f
+            };
+};
+
+// Default input and expected output shared by the Simple4dSoftmax*Test variants: the
+// same eight values as the 3D fixture, laid out as { 1, 8, 1, 1 } (axis 1, beta = 1).
+struct Simple4dSoftmaxData
+{
+    const armnn::TensorShape inputShape{ 1, 8, 1, 1 };
+
+    const std::vector<float> outputData = { 0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
+                                            0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f };
+    const std::vector<float> inputData =
+            {
+                    0.f, 1.f, 0.f, 0.f,
+                    .5f, 0.f, 0.f, 0.f
+            };
+};
+
 // Helper function that returns either Bias2 or an empty vector depending on whether bias is enabled.
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 boost::multi_array<T, 1> GetBias2(bool biasEnabled, float qScale)
@@ -1647,12 +1677,123 @@
     return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta);
 }
 
+// Runs the Float32 2D softmax test along `axis` (0/1, or -2/-1 counted from the back).
+LayerTestResult<float,2> SimpleAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta, axis);
+}
+
 LayerTestResult<float,3> Simple3dSoftmaxTest(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta);
+    Simple3dSoftmaxOutputData data;
+    return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta,
+                                                             data.inputShape, data.outputData, data.inputData);
+}
+
+// Runs the Float32 3D softmax test along `axis` (0..2, or -3..-1 counted from the back).
+// The data is laid out so that every slice along the softmax axis is {17, 16, 15, 14, 1}
+// or {-1, -2, -3, -4, -17}; both give the same expected values, which assume beta == 1.
+LayerTestResult<float,3> Simple3dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    armnn::TensorShape inputShape;
+    std::vector<float> inputData;
+    std::vector<float> outputData;
+    switch (axis)
+    {
+    case -3:
+    case 0:
+        {
+            inputShape = {5, 2, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
+
+                            15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
+                            0.087144312427294f,
+
+                            0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -2:
+    case 1:
+        {
+            inputShape = {2, 5, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -1:
+    case 2:
+        {
+            inputShape = {2, 2, 5};
+
+            inputData =
+                    {
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    default:
+        throw armnn::InvalidArgumentException("Simple3dAxisSoftmaxTest: axis must be in the range [-3, 2]");
+    }
+
+    return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta,
+                                                             inputShape, outputData, inputData, axis);
 }
 
 LayerTestResult<float,4> Simple4dSoftmaxTest(
@@ -1660,7 +1795,172 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta);
+    Simple4dSoftmaxData data;
+    return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta, data.inputShape,
+                                                             data.outputData, data.inputData);
+}
+
+// Runs the Float32 4D softmax test along `axis` (0..3, or -4..-1 counted from the back).
+// The data is laid out so that every slice along the softmax axis is {17, 16, 15, 14, 1}
+// or {-1, -2, -3, -4, -17}; both give the same expected values, which assume beta == 1.
+LayerTestResult<float,4> Simple4dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    armnn::TensorShape inputShape;
+    std::vector<float> inputData;
+    std::vector<float> outputData;
+    switch (axis)
+    {
+    case -4:
+    case 0:
+        {
+            inputShape = {5, 2, 2, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f,
+                            16.0f, -2.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f, 15.0f, -3.0f,
+                            15.0f, -3.0f, 15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 14.0f, -4.0f,
+                            14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.643914213228014f,
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.236882800924671f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.236882800924671f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
+                            0.087144312427294f,
+
+                            0.087144312427294f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
+                            0.032058600957022f,
+                            0.032058600957022f, 0.032058600957022f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+                            7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f, 7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -3:
+    case 1:
+        {
+            inputShape = {2, 5, 2, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
+                            15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f,
+                            17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
+                            15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+
+
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -2:
+    case 2:
+        {
+            inputShape = {2, 2, 5, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -1:
+    case 3:
+        {
+            inputShape = {2, 2, 2, 5};
+
+            inputData =
+                    {
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    default:
+        throw armnn::InvalidArgumentException("Simple4dAxisSoftmaxTest: axis must be in the range [-4, 3]");
+    }
+
+    return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta, inputShape,
+                                                             outputData, inputData, axis);
 }
 
 LayerTestResult<uint8_t,2> SimpleSoftmaxUint8Test(
@@ -1676,7 +1971,9 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta);
+    Simple3dSoftmaxOutputData data;
+    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<uint8_t,4> Simple4dSoftmaxUint8Test(
@@ -1684,7 +1981,9 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta);
+    Simple4dSoftmaxData data;
+    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<int16_t,2> SimpleSoftmaxUint16Test(
@@ -1700,7 +1999,9 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta);
+    Simple3dSoftmaxOutputData data;
+    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<int16_t,4> Simple4dSoftmaxUint16Test(
@@ -1708,7 +2009,9 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta);
+    Simple4dSoftmaxData data;
+    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<float,4> CompareNormalizationTest(
diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp
index d99e3b4..913c3a6 100644
--- a/src/backends/backendsCommon/test/LayerTests.hpp
+++ b/src/backends/backendsCommon/test/LayerTests.hpp
@@ -472,16 +472,37 @@
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     float beta);
 
+// Float32 softmax of a 2D tensor along `axis` (0, 1 or the negative equivalents -2, -1).
+LayerTestResult<float, 2> SimpleAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis);
+
 LayerTestResult<float, 3> Simple3dSoftmaxTest(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta);
 
+// Float32 softmax of a 3D tensor along `axis` (0..2 or the negative equivalents -3..-1).
+LayerTestResult<float, 3> Simple3dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis);
+
 LayerTestResult<float, 4> Simple4dSoftmaxTest(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta);
 
+// Float32 softmax of a 4D tensor along `axis` (0..3 or the negative equivalents -4..-1).
+LayerTestResult<float, 4> Simple4dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis);
+
 LayerTestResult<uint8_t, 2> SimpleSoftmaxUint8Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
diff --git a/src/backends/backendsCommon/test/SoftmaxTestImpl.hpp b/src/backends/backendsCommon/test/SoftmaxTestImpl.hpp
index 8081950..983a53b 100644
--- a/src/backends/backendsCommon/test/SoftmaxTestImpl.hpp
+++ b/src/backends/backendsCommon/test/SoftmaxTestImpl.hpp
@@ -25,7 +25,11 @@
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     float beta,
     const armnn::TensorShape& inputShape,
-    const std::vector<float>& outputData)
+    const std::vector<float>& outputData,
+    // Float values; converted with QuantizedVector<T>(qScale, qOffset, ...) before use.
+    const std::vector<float>& inputData,
+    // Written to data.m_Parameters.m_Axis; defaults to 1 for the pre-existing fixed-shape tests.
+    int axis = 1)
 {
     using std::exp;
 
@@ -47,16 +49,14 @@
 
-    // Each row is independently softmax'd.
+    // Each 1D slice along the softmax axis is normalised independently.
     auto input = MakeTensor<T, n>(inputTensorInfo, std::vector<T>(
-        QuantizedVector<T>(qScale, qOffset, {
-            0.f, 1.f, 0.f, 0.f,
-            .5f, 0.f, 0.f, 0.f,
-        })));
+        QuantizedVector<T>(qScale, qOffset, inputData)));
 
     std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
     std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
 
     armnn::SoftmaxQueueDescriptor data;
     data.m_Parameters.m_Beta = beta;
+    data.m_Parameters.m_Axis = axis;
 
     armnn::WorkloadInfo info;
     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
@@ -100,33 +100,102 @@
     const std::vector<float> outputData = { x0[0] / sum0, x0[1] / sum0, x0[2] / sum0, x0[3] / sum0,
                                             x1[0] / sum1, x1[1] / sum1, x1[2] / sum1, x1[3] / sum1 };
 
-    return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, beta, inputShape, outputData);
+    const std::vector<float> inputData =
+            {
+                0.f, 1.f, 0.f, 0.f,
+                .5f, 0.f, 0.f, 0.f,
+            };
+
+    return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, beta,
+                                                   inputShape, outputData, inputData);
+}
+
+// Softmax over a 2D tensor along `axis` (0/1, or -2/-1 counted from the back). The data is
+// laid out so that every slice along the softmax axis is {17, 16, 15, 14, 1} or
+// {-1, -2, -3, -4, -17}; both give the same expected values, which assume beta == 1.
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    armnn::TensorShape inputShape;
+    std::vector<float> inputData;
+    std::vector<float> outputData;
+    switch (axis)
+    {
+    case -2:
+    case 0:
+        {
+        inputShape = {5, 2};
+
+        inputData =
+                {
+                        17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
+                };
+
+        outputData =
+                {
+                        0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                        0.087144312427294f,
+                        0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                        7.246299848982885e-08f
+                };
+        break;
+        }
+    case -1:
+    case 1:
+        {
+        inputShape = {2, 5};
+
+        inputData =
+                {
+                        17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
+                };
+
+        outputData =
+                {
+                        0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                        7.246299848982885e-08f,
+                        0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                        7.246299848982885e-08f
+                };
+        break;
+        }
+    default:
+        throw armnn::InvalidArgumentException("SimpleSoftmaxTestImpl: axis must be in the range [-2, 1]");
+    }
+    return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, beta,
+                                                   inputShape, outputData, inputData, axis);
 }
 
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 3> Simple3dSoftmaxTestImpl(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
-    float beta)
+    float beta,
+    const armnn::TensorShape& inputShape,
+    const std::vector<float>& outputData,
+    const std::vector<float>& inputData,
+    int axis = 1)
 {
-    const armnn::TensorShape inputShape{ 1, 8, 1 };
-    const std::vector<float> outputData = { 0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
-                                            0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f };
-
-    return SimpleSoftmaxBaseTestImpl<ArmnnType, 3>(workloadFactory, memoryManager, beta, inputShape, outputData);
+    return SimpleSoftmaxBaseTestImpl<ArmnnType, 3>(workloadFactory, memoryManager, beta,
+                                                   inputShape, outputData, inputData, axis);
 }
 
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 4> Simple4dSoftmaxTestImpl(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
-    float beta)
+    float beta,
+    const armnn::TensorShape& inputShape,
+    const std::vector<float>& outputData,
+    const std::vector<float>& inputData,
+    int axis = 1)
 {
-    const armnn::TensorShape inputShape{ 1, 8, 1, 1 };
-    const std::vector<float> outputData = { 0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
-                                            0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f };
-
-    return SimpleSoftmaxBaseTestImpl<ArmnnType, 4>(workloadFactory, memoryManager, beta, inputShape, outputData);
+    return SimpleSoftmaxBaseTestImpl<ArmnnType, 4>(workloadFactory, memoryManager, beta,
+                                                   inputShape, outputData, inputData, axis);
 }
 
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>