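[COMPMID-386] Github: Support SoftmaxLayer on different number of dimensions?

Add an `axis` parameter to the reference softmax implementation and to the
softmax validation fixtures, and extend the CL, GLES_COMPUTE and NEON
validation tests with an "Axis" dataset. The NEON FP16 RunLarge test now
uses SoftmaxLayerLargeShapes instead of SoftmaxLayerSmallShapes.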

Change-Id: I7422b977538ff29930a90f078badc2edee78af93
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/146638
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/tests/validation/CL/SoftmaxLayer.cpp b/tests/validation/CL/SoftmaxLayer.cpp
index 7dab626..c9ef35d 100644
--- a/tests/validation/CL/SoftmaxLayer.cpp
+++ b/tests/validation/CL/SoftmaxLayer.cpp
@@ -134,23 +134,26 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SoftmaxLayerSmallShapes(),
-                                                                                                           framework::dataset::make("DataType", DataType::F16)),
-                                                                                                   framework::dataset::make("Beta", { 1.0f, 2.0f })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
+                                                                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                                                           framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                   framework::dataset::make("Axis", { 1, 2 })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
-                                                                                                               framework::dataset::make("DataType", DataType::F16)),
-                                                                                                       framework::dataset::make("Beta", { 1.0f, 2.0f })))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                                       framework::dataset::make("DataType", DataType::F16)),
+                                                                                                               framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                       framework::dataset::make("Axis", { 1, 2 })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(Run4D, CLSoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayer4DShapes(),
-                                                                                                            framework::dataset::make("DataType", DataType::F16)),
-                                                                                                    framework::dataset::make("Beta", { 1.0f, 2.0f })))
+FIXTURE_DATA_TEST_CASE(Run4D, CLSoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayer4DShapes(),
+                                                                                                                    framework::dataset::make("DataType", DataType::F16)),
+                                                                                                            framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                    framework::dataset::make("Axis", { 1, 2, 3 })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
@@ -158,23 +161,26 @@
 TEST_SUITE_END()
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::SoftmaxLayerSmallShapes(),
-                                                                                                            framework::dataset::make("DataType", DataType::F32)),
-                                                                                                    framework::dataset::make("Beta", { 1.0f, 2.0f })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
+                                                                                                                    framework::dataset::make("DataType", DataType::F32)),
+                                                                                                            framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                    framework::dataset::make("Axis", { 1, 2 })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
-                                                                                                                framework::dataset::make("DataType", DataType::F32)),
-                                                                                                        framework::dataset::make("Beta", { 1.0f, 2.0f })))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                                        framework::dataset::make("DataType", DataType::F32)),
+                                                                                                                framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                        framework::dataset::make("Axis", { 1, 2 })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(Run4D, CLSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayer4DShapes(),
-                                                                                                             framework::dataset::make("DataType", DataType::F32)),
-                                                                                                     framework::dataset::make("Beta", { 1.0f, 2.0f })))
+FIXTURE_DATA_TEST_CASE(Run4D, CLSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayer4DShapes(),
+                                                                                                                     framework::dataset::make("DataType", DataType::F32)),
+                                                                                                             framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                     framework::dataset::make("Axis", { 1, 2, 3 })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -187,26 +193,29 @@
 
 TEST_SUITE(Quantized)
 TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
                                                                                                                        framework::dataset::make("DataType", DataType::QASYMM8)),
-                                                                                                               combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
-                                                                                                                       framework::dataset::make("Beta", { 1.0f, 2.f }))))
+                                                                                                                       combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
+                                                                                                                               framework::dataset::make("Beta", { 1.0f, 2.f }))),
+                                                                                                               framework::dataset::make("Axis", { 1, 2 })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_qasymm8);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
                                                                                                                    framework::dataset::make("DataType", DataType::QASYMM8)),
                                                                                                                    combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
-                                                                                                                           framework::dataset::make("Beta", { 1.0f, 2.0f }))))
+                                                                                                                           framework::dataset::make("Beta", { 1.0f, 2.0f }))),
+                                                                                                                   framework::dataset::make("Axis", { 1 })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_qasymm8);
 }
-FIXTURE_DATA_TEST_CASE(Run4D, CLSoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayer4DShapes(),
+FIXTURE_DATA_TEST_CASE(Run4D, CLSoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayer4DShapes(),
                                                                                                                         framework::dataset::make("DataType", DataType::QASYMM8)),
-                                                                                                                combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
-                                                                                                                        framework::dataset::make("Beta", { 1.0f, 2.0f }))))
+                                                                                                                        combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
+                                                                                                                                framework::dataset::make("Beta", { 1.0f, 2.0f }))),
+                                                                                                                framework::dataset::make("Axis", { 1, 2, 3 })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/GLES_COMPUTE/SoftmaxLayer.cpp b/tests/validation/GLES_COMPUTE/SoftmaxLayer.cpp
index abc277a..3b55717 100644
--- a/tests/validation/GLES_COMPUTE/SoftmaxLayer.cpp
+++ b/tests/validation/GLES_COMPUTE/SoftmaxLayer.cpp
@@ -86,16 +86,18 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, GCSoftmaxLayerFixture<half_float::half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+FIXTURE_DATA_TEST_CASE(RunSmall, GCSoftmaxLayerFixture<half_float::half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
                                                                                                                      framework::dataset::make("DataType", DataType::F16)),
-                                                                                                                     framework::dataset::make("Beta", 1.0f)))
+                                                                                                                     framework::dataset::make("Beta", 1.0f)),
+                                                                                                                     framework::dataset::make("Axis", 1)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, GCSoftmaxLayerFixture<half_float::half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
+FIXTURE_DATA_TEST_CASE(RunLarge, GCSoftmaxLayerFixture<half_float::half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
                                                                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                                                                   framework::dataset::make("Beta", 1.0f)))
+                                                                                                                   framework::dataset::make("Beta", 1.0f)),
+                                                                                                                   framework::dataset::make("Axis", 1)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f16);
@@ -103,16 +105,18 @@
 TEST_SUITE_END()
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, GCSoftmaxLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+FIXTURE_DATA_TEST_CASE(RunSmall, GCSoftmaxLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
                                                                                                                   framework::dataset::make("DataType", DataType::F32)),
-                                                                                                          framework::dataset::make("Beta", 1.0f)))
+                                                                                                                  framework::dataset::make("Beta", 1.0f)),
+                                                                                                          framework::dataset::make("Axis", 1)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, GCSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
-                                                                                                                framework::dataset::make("DataType", DataType::F32)),
-                                                                                                        framework::dataset::make("Beta", 1.0f)))
+FIXTURE_DATA_TEST_CASE(RunLarge, GCSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                                        framework::dataset::make("DataType", DataType::F32)),
+                                                                                                                framework::dataset::make("Beta", 1.0f)),
+                                                                                                        framework::dataset::make("Axis", 1)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f32);
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index a5d6344..21c77e7 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -118,16 +118,18 @@
 TEST_SUITE(Float)
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
                                                                                                                  framework::dataset::make("DataType", DataType::F16)),
-                                                                                                         framework::dataset::make("Beta", { 1.0f, 2.0f })))
+                                                                                                                 framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                         framework::dataset::make("Axis", { 1 })))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerSmallShapes(),
-                                                                                                               framework::dataset::make("DataType", DataType::F16)),
-                                                                                                       framework::dataset::make("Beta", { 1.0f, 2.0f })))
+FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                                       framework::dataset::make("DataType", DataType::F16)),
+                                                                                                               framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                       framework::dataset::make("Axis", { 1 })))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
@@ -136,16 +138,18 @@
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
                                                                                                                   framework::dataset::make("DataType", DataType::F32)),
-                                                                                                          framework::dataset::make("Beta", { 1.0f, 2.0f })))
+                                                                                                                  framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                          framework::dataset::make("Axis", { 1 })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
-                                                                                                                framework::dataset::make("DataType", DataType::F32)),
-                                                                                                        framework::dataset::make("Beta", { 1.0f, 2.0f })))
+FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                                        framework::dataset::make("DataType", DataType::F32)),
+                                                                                                                framework::dataset::make("Beta", { 1.0f, 2.0f })),
+                                                                                                        framework::dataset::make("Axis", { 1 })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f32);
@@ -158,22 +162,25 @@
 
 TEST_SUITE(Quantized)
 TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
                                                                                                                        framework::dataset::make("DataType", DataType::QASYMM8)),
-                                                                                                               combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
-                                                                                                                       framework::dataset::make("Beta", { 1.0f, 2.0f }))))
+                                                                                                                       combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
+                                                                                                                               framework::dataset::make("Beta", { 1.0f, 2.0f }))),
+                                                                                                               framework::dataset::make("Axis", { 1 })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
+FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
                                                                                                                    framework::dataset::make("DataType", DataType::QASYMM8)),
                                                                                                                    combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
-                                                                                                                           framework::dataset::make("Beta", { 1.0f, 2.0f }))))
+                                                                                                                           framework::dataset::make("Beta", { 1.0f, 2.0f }))),
+                                                                                                                   framework::dataset::make("Axis", { 1 })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
+
 TEST_SUITE_END()
 TEST_SUITE_END()
 
diff --git a/tests/validation/fixtures/SoftmaxLayerFixture.h b/tests/validation/fixtures/SoftmaxLayerFixture.h
index 99c0710..e39ee74 100644
--- a/tests/validation/fixtures/SoftmaxLayerFixture.h
+++ b/tests/validation/fixtures/SoftmaxLayerFixture.h
@@ -47,12 +47,12 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, float beta)
+    void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, float beta, size_t axis)
     {
         _quantization_info = quantization_info;
 
-        _target    = compute_target(shape, data_type, quantization_info, beta);
-        _reference = compute_reference(shape, data_type, quantization_info, beta);
+        _target    = compute_target(shape, data_type, quantization_info, beta, axis);
+        _reference = compute_reference(shape, data_type, quantization_info, beta, axis);
     }
 
 protected:
@@ -72,7 +72,7 @@
     }
 
     TensorType compute_target(const TensorShape &shape, DataType data_type,
-                              QuantizationInfo quantization_info, float beta)
+                              QuantizationInfo quantization_info, float beta, size_t axis)
     {
         // Create tensors
         TensorType src = create_tensor<TensorType>(shape, data_type, 1, quantization_info);
@@ -80,7 +80,7 @@
 
         // Create and configure function
         FunctionType smx_layer;
-        smx_layer.configure(&src, &dst, beta);
+        smx_layer.configure(&src, &dst, beta, axis);
 
         ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -102,7 +102,7 @@
     }
 
     SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type,
-                                      QuantizationInfo quantization_info, float beta)
+                                      QuantizationInfo quantization_info, float beta, size_t axis)
     {
         // Create reference
         SimpleTensor<T> src{ shape, data_type, 1, quantization_info };
@@ -110,7 +110,7 @@
         // Fill reference
         fill(src);
 
-        return reference::softmax_layer<T>(src, beta);
+        return reference::softmax_layer<T>(src, beta, axis);
     }
 
     TensorType       _target{};
@@ -123,12 +123,13 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, float beta)
+    void setup(TensorShape shape, DataType data_type, float beta, size_t axis)
     {
         SoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
                                                                                           data_type,
                                                                                           QuantizationInfo(),
-                                                                                          beta);
+                                                                                          beta,
+                                                                                          axis);
     }
 };
 
@@ -137,12 +138,13 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, float beta)
+    void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, float beta, size_t axis)
     {
         SoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
                                                                                           data_type,
                                                                                           quantization_info,
-                                                                                          beta);
+                                                                                          beta,
+                                                                                          axis);
     }
 };
 } // namespace validation
diff --git a/tests/validation/reference/SoftmaxLayer.cpp b/tests/validation/reference/SoftmaxLayer.cpp
index 7f2c36e..f1b94c0 100644
--- a/tests/validation/reference/SoftmaxLayer.cpp
+++ b/tests/validation/reference/SoftmaxLayer.cpp
@@ -34,18 +34,26 @@
 namespace reference
 {
 template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
-SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta)
+SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis)
 {
     // Create reference
     SimpleTensor<T> dst{ src.shape(), src.data_type(), 1 };
 
-    const bool is_4D_input = (src.shape().num_dimensions() > 2);
+    // Compute reference. lower_dims is the product of the first `axis` dimensions
+    // (i.e., the flattened size of each batch); upper_dims is the product of the
+    // remaining dimensions (i.e., the number of batches we want to normalize)
 
-    // Compute reference. Lower dims are
-    // - the number of columns for the 2D case
-    // - the collapsing of the first three dimensions (i.e., the flattened dimension of each batch) in the 4D case
-    const int lower_dims = (is_4D_input ? src.shape()[2] * src.shape()[1] * src.shape()[0] : src.shape()[0]);
-    const int upper_dims = src.num_elements() / lower_dims;
+    int lower_dims = 1;
+    for(size_t i = 0; i < axis; i++)
+    {
+        lower_dims *= src.shape()[i];
+    }
+
+    int upper_dims = 1;
+    for(size_t i = axis; i < TensorShape::num_max_dimensions; i++)
+    {
+        upper_dims *= src.shape()[i];
+    }
 
     for(int r = 0; r < upper_dims; ++r)
     {
@@ -75,20 +83,20 @@
 }
 
 template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value, int>::type>
-SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta)
+SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis)
 {
     // Note: Output quantization info should always have scale = 1/256 and offset = 0
     const QuantizationInfo output_quantization_info = QuantizationInfo(1.f / 256, 0);
 
     SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
-    SimpleTensor<float> dst_tmp = softmax_layer<float>(src_tmp, beta);
+    SimpleTensor<float> dst_tmp = softmax_layer<float>(src_tmp, beta, axis);
     SimpleTensor<T>     dst     = convert_to_asymmetric(dst_tmp, output_quantization_info);
     return dst;
 }
 
-template SimpleTensor<float> softmax_layer(const SimpleTensor<float> &src, float beta);
-template SimpleTensor<half> softmax_layer(const SimpleTensor<half> &src, float beta);
-template SimpleTensor<uint8_t> softmax_layer(const SimpleTensor<uint8_t> &src, float beta);
+template SimpleTensor<float> softmax_layer(const SimpleTensor<float> &src, float beta, size_t axis);
+template SimpleTensor<half> softmax_layer(const SimpleTensor<half> &src, float beta, size_t axis);
+template SimpleTensor<uint8_t> softmax_layer(const SimpleTensor<uint8_t> &src, float beta, size_t axis);
 } // namespace reference
 } // namespace validation
 } // namespace test
diff --git a/tests/validation/reference/SoftmaxLayer.h b/tests/validation/reference/SoftmaxLayer.h
index 21dca1e..d21ca2b 100644
--- a/tests/validation/reference/SoftmaxLayer.h
+++ b/tests/validation/reference/SoftmaxLayer.h
@@ -36,10 +36,10 @@
 namespace reference
 {
 template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
-SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta);
+SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis = 1);
 
 template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value, int>::type = 0>
-SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta);
+SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis = 1);
 } // namespace reference
 } // namespace validation
 } // namespace test