IVGCVSW-3536 Add Axis parameter to reference Softmax implementation

 * Add Axis parameter to Softmax Descriptor
 * Add new reference implementation for Softmax using Axis parameter
 * Add unit tests covering every axis, including negative axis values

Change-Id: Iafac2275d2212337456f2b1b56b0f76f77fb9543
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
diff --git a/src/backends/backendsCommon/test/LayerTests.cpp b/src/backends/backendsCommon/test/LayerTests.cpp
index d6e0e87..b40a3f5 100644
--- a/src/backends/backendsCommon/test/LayerTests.cpp
+++ b/src/backends/backendsCommon/test/LayerTests.cpp
@@ -77,6 +77,44 @@
 // 2-channel bias used by a number of Conv2d tests.
 static std::vector<float> Bias2({0, 2});
 
+// Input/expected-output fixture for the 3d softmax tests (input shape 1x8x1).
+// NOTE(review): despite its name this struct also carries the input data.
+struct Simple3dSoftmaxOutputData
+{
+    const armnn::TensorShape inputShape{ 1, 8, 1 };
+
+    const std::vector<float> inputData =
+    {
+        0.f, 1.f, 0.f, 0.f,
+        .5f, 0.f, 0.f, 0.f
+    };
+
+    const std::vector<float> outputData =
+    {
+        0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
+        0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f
+    };
+};
+
+// Input/expected-output fixture for the 4d softmax tests (input shape 1x8x1x1);
+// the values are identical to the 3d fixture, only the tensor rank differs.
+struct Simple4dSoftmaxData
+{
+    const armnn::TensorShape inputShape{ 1, 8, 1, 1 };
+
+    const std::vector<float> inputData =
+    {
+        0.f, 1.f, 0.f, 0.f,
+        .5f, 0.f, 0.f, 0.f
+    };
+
+    const std::vector<float> outputData =
+    {
+        0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
+        0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f
+    };
+};
+
 // Helper function that returns either Bias2 or an empty vector depending on whether bias is enabled.
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 boost::multi_array<T, 1> GetBias2(bool biasEnabled, float qScale)
@@ -1647,12 +1677,128 @@
     return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta);
 }
 
+// Float32 variant of the simple softmax test that forwards 'axis' to SimpleSoftmaxTestImpl.
+LayerTestResult<float,2> SimpleAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta, axis);
+}
+
 LayerTestResult<float,3> Simple3dSoftmaxTest(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta);
+    Simple3dSoftmaxOutputData data;
+    return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta,
+                                                             data.inputShape, data.outputData, data.inputData);
+}
+
+// Float32 3d softmax test with the softmax taken along 'axis'.
+// Accepted axes are -3..2; a negative axis is an alias for (axis + 3). Each case builds a tensor whose
+// selected dimension has length 5 (all others have length 2), and every length-5 row is {17, 16, 15, 14, 1}
+// or that row shifted by -18, so all rows share the same expected softmax values.
+LayerTestResult<float,3> Simple3dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    armnn::TensorShape inputShape;
+    std::vector<float> inputData;
+    std::vector<float> outputData;
+    switch (axis)
+    {
+    case -3:
+    case 0:
+        {
+            inputShape = {5, 2, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
+
+                            15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
+                            0.087144312427294f,
+
+                            0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -2:
+    case 1:
+        {
+            inputShape = {2, 5, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -1:
+    case 2:
+        {
+            inputShape = {2, 2, 5};
+
+            inputData =
+                    {
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    default:
+        {
+            // Fail loudly rather than run the workload with an empty shape and no data.
+            throw armnn::InvalidArgumentException("Simple3dAxisSoftmaxTest: unsupported axis " +
+                                                  std::to_string(axis) + ", expected a value in [-3, 2]");
+        }
+    }
+
+    return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta,
+                                                             inputShape, outputData, inputData, axis);
 }
 
 LayerTestResult<float,4> Simple4dSoftmaxTest(
@@ -1660,7 +1795,177 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta);
+    Simple4dSoftmaxData data;
+    return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta, data.inputShape,
+                                                             data.outputData, data.inputData);
+}
+
+// Float32 4d softmax test with the softmax taken along 'axis'.
+// Accepted axes are -4..3; a negative axis is an alias for (axis + 4). Each case builds a tensor whose
+// selected dimension has length 5 (all others have length 2), and every length-5 row is {17, 16, 15, 14, 1}
+// or that row shifted by -18, so all rows share the same expected softmax values.
+LayerTestResult<float,4> Simple4dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    armnn::TensorShape inputShape;
+    std::vector<float> inputData;
+    std::vector<float> outputData;
+    switch (axis)
+    {
+    case -4:
+    case 0:
+        {
+            inputShape = {5, 2, 2, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f,
+                            16.0f, -2.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f, 15.0f, -3.0f,
+                            15.0f, -3.0f, 15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 14.0f, -4.0f,
+                            14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.643914213228014f,
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.236882800924671f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.236882800924671f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
+                            0.087144312427294f,
+
+                            0.087144312427294f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
+                            0.032058600957022f,
+                            0.032058600957022f, 0.032058600957022f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+                            7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f, 7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -3:
+    case 1:
+        {
+            inputShape = {2, 5, 2, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
+                            15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f,
+                            17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
+                            15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+
+
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -2:
+    case 2:
+        {
+            inputShape = {2, 2, 5, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -1:
+    case 3:
+        {
+            inputShape = {2, 2, 2, 5};
+
+            inputData =
+                    {
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    default:
+        {
+            // Fail loudly rather than run the workload with an empty shape and no data.
+            throw armnn::InvalidArgumentException("Simple4dAxisSoftmaxTest: unsupported axis " +
+                                                  std::to_string(axis) + ", expected a value in [-4, 3]");
+        }
+    }
+
+    return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta, inputShape,
+                                                             outputData, inputData, axis);
 }
 
 LayerTestResult<uint8_t,2> SimpleSoftmaxUint8Test(
@@ -1676,7 +1971,9 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta);
+    Simple3dSoftmaxOutputData data;
+    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<uint8_t,4> Simple4dSoftmaxUint8Test(
@@ -1684,7 +1981,9 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta);
+    Simple4dSoftmaxData data;
+    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<int16_t,2> SimpleSoftmaxUint16Test(
@@ -1700,7 +2000,9 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta);
+    Simple3dSoftmaxOutputData data;
+    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<int16_t,4> Simple4dSoftmaxUint16Test(
@@ -1708,7 +2010,9 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta);
+    Simple4dSoftmaxData data;
+    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<float,4> CompareNormalizationTest(