IVGCVSW-631 Neon support for Softmax beta parameter (F32 only)

Change-Id: Ibf6f038b39f1a4e557f5d04feb08e3d5ef54e223
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/112019
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/tests/validation/CL/SoftmaxLayer.cpp b/tests/validation/CL/SoftmaxLayer.cpp
index f43e680..bd70723 100644
--- a/tests/validation/CL/SoftmaxLayer.cpp
+++ b/tests/validation/CL/SoftmaxLayer.cpp
@@ -148,12 +148,16 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixture<half>, framework::DatasetMode::ALL, combine(datasets::SoftmaxLayerSmallShapes(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+                                                                                                           framework::dataset::make("DataType", DataType::F16)),
+                                                                                                   framework::dataset::make("Beta", { 1.0f, 2.0f }))) // beta = 1.0f is plain softmax; 2.0f exercises the exp((x - max) * beta) scaling
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::SoftmaxLayerLargeShapes(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                               framework::dataset::make("DataType", DataType::F16)),
+                                                                                                       framework::dataset::make("Beta", { 1.0f, 2.0f })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
@@ -161,12 +165,16 @@
 TEST_SUITE_END()
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixture<float>, framework::DatasetMode::ALL, combine(datasets::SoftmaxLayerSmallShapes(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+                                                                                                            framework::dataset::make("DataType", DataType::F32)),
+                                                                                                    framework::dataset::make("Beta", { 1.0f, 2.0f })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::SoftmaxLayerLargeShapes(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                                framework::dataset::make("DataType", DataType::F32)),
+                                                                                                        framework::dataset::make("Beta", { 1.0f, 2.0f })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -223,17 +231,17 @@
 TEST_SUITE(Quantized)
 TEST_SUITE(QASYMM8)
 FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::SoftmaxLayerSmallShapes(),
-                                                                                                                       framework::dataset::make("DataType",
-                                                                                                                               DataType::QASYMM8)),
-                                                                                                               framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) })))
+                                                                                                                       framework::dataset::make("DataType", DataType::QASYMM8)),
+                                                                                                               combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
+                                                                                                                       framework::dataset::make("Beta", { 1.0f, 2.0f })))) // same beta values as the float tests
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_qasymm8);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
-                                                                                                                   framework::dataset::make("DataType",
-                                                                                                                           DataType::QASYMM8)),
-                                                                                                                   framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) })))
+                                                                                                                   framework::dataset::make("DataType", DataType::QASYMM8)),
+                                                                                                                   combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
+                                                                                                                           framework::dataset::make("Beta", { 1.0f, 2.0f }))))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/GLES_COMPUTE/SoftmaxLayer.cpp b/tests/validation/GLES_COMPUTE/SoftmaxLayer.cpp
index a2114a9..2c28141 100644
--- a/tests/validation/GLES_COMPUTE/SoftmaxLayer.cpp
+++ b/tests/validation/GLES_COMPUTE/SoftmaxLayer.cpp
@@ -57,7 +57,7 @@
 TEST_SUITE(GC)
 TEST_SUITE(SoftmaxLayer)
 
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(concat(datasets::SmallShapes(), datasets::LargeShapes()), CNNDataTypes), shape, data_type)
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(concat(datasets::SoftmaxLayerSmallShapes(), datasets::SoftmaxLayerLargeShapes()), CNNDataTypes), shape, data_type)
 {
     // Set fixed point position data type allowed
     const int fixed_point_position = is_data_type_fixed_point(data_type) ? 3 : 0;
@@ -89,12 +89,16 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, GCSoftmaxLayerFixture<half_float::half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, GCSoftmaxLayerFixture<half_float::half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+                                                                                                                     framework::dataset::make("DataType", DataType::F16)),
+                                                                                                                     framework::dataset::make("Beta", 1.0f))) // only beta = 1.0f is exercised for GLES compute
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, GCSoftmaxLayerFixture<half_float::half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, GCSoftmaxLayerFixture<half_float::half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                                                                   framework::dataset::make("Beta", 1.0f)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f16);
@@ -102,12 +106,16 @@
 TEST_SUITE_END()
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, GCSoftmaxLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, GCSoftmaxLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+                                                                                                                  framework::dataset::make("DataType", DataType::F32)),
+                                                                                                          framework::dataset::make("Beta", 1.0f)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, GCSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, GCSoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                                framework::dataset::make("DataType", DataType::F32)),
+                                                                                                        framework::dataset::make("Beta", 1.0f)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f32);
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index 9d1795e..1a303e1 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -65,7 +65,7 @@
 TEST_SUITE(NEON)
 TEST_SUITE(SoftmaxLayer)
 
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(concat(datasets::SmallShapes(), datasets::LargeShapes()), CNNDataTypes), shape, data_type)
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(concat(datasets::SoftmaxLayerSmallShapes(), datasets::SoftmaxLayerLargeShapes()), CNNDataTypes), shape, data_type)
 {
     // Set fixed point position data type allowed
     const int fixed_point_position = is_data_type_fixed_point(data_type) ? 3 : 0;
@@ -99,12 +99,16 @@
 TEST_SUITE(Float)
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+                                                                                                                 framework::dataset::make("DataType", DataType::F16)),
+                                                                                                         framework::dataset::make("Beta", { 1.0f, 2.0f }))) // NOTE(review): commit title says NEON beta support is F32 only - confirm the F16 path honours beta = 2.0f
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                               framework::dataset::make("DataType", DataType::F16)),
+                                                                                                       framework::dataset::make("Beta", { 1.0f, 2.0f }))) // NIGHTLY run uses the large softmax shapes, like every other RunLarge case
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f16);
@@ -113,12 +117,16 @@
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(),
+                                                                                                                  framework::dataset::make("DataType", DataType::F32)),
+                                                                                                          framework::dataset::make("Beta", { 1.0f, 2.0f }))) // beta != 1.0f exercises the F32 beta support added by this change
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
+                                                                                                                framework::dataset::make("DataType", DataType::F32)),
+                                                                                                        framework::dataset::make("Beta", { 1.0f, 2.0f })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f32);
@@ -132,14 +140,14 @@
 TEST_SUITE(Quantized)
 TEST_SUITE(QS8)
 // Testing for fixed point position [1,6) as reciprocal limits the maximum fixed point position to 5
-FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(), framework::dataset::make("DataType",
                                                                                                                      DataType::QS8)),
                                                                                                                      framework::dataset::make("FractionalBits", 1, 6)))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_fixed_point);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(), framework::dataset::make("DataType",
                                                                                                                    DataType::QS8)),
                                                                                                                    framework::dataset::make("FractionalBits", 1, 6)))
 {
@@ -150,7 +158,7 @@
 
 TEST_SUITE(QS16)
 // Testing for fixed point position [1,14) as reciprocal limits the maximum fixed point position to 14
-FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SoftmaxLayerSmallShapes(),
                                                                                                                       framework::dataset::make("DataType",
                                                                                                                               DataType::QS16)),
                                                                                                                       framework::dataset::make("FractionalBits", 1, 14)))
@@ -158,7 +166,7 @@
     // Validate output
     validate(Accessor(_target), _reference, tolerance_fixed_point);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SoftmaxLayerLargeShapes(),
                                                                                                                     framework::dataset::make("DataType",
                                                                                                                             DataType::QS16)),
                                                                                                                     framework::dataset::make("FractionalBits", 1, 14)))
diff --git a/tests/validation/fixtures/SoftmaxLayerFixture.h b/tests/validation/fixtures/SoftmaxLayerFixture.h
index 3ffbc6a..c2ab2e2 100644
--- a/tests/validation/fixtures/SoftmaxLayerFixture.h
+++ b/tests/validation/fixtures/SoftmaxLayerFixture.h
@@ -47,13 +47,13 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, int fractional_bits, QuantizationInfo quantization_info)
+    void setup(TensorShape shape, DataType data_type, int fractional_bits, QuantizationInfo quantization_info, float beta) // beta is forwarded to both the function under test and the reference
     {
         _fractional_bits   = fractional_bits;
         _quantization_info = quantization_info;
 
-        _target    = compute_target(shape, data_type, fractional_bits, quantization_info);
-        _reference = compute_reference(shape, data_type, fractional_bits, quantization_info);
+        _target    = compute_target(shape, data_type, fractional_bits, quantization_info, beta);
+        _reference = compute_reference(shape, data_type, fractional_bits, quantization_info, beta);
     }
 
 protected:
@@ -78,7 +78,8 @@
         }
     }
 
-    TensorType compute_target(const TensorShape &shape, DataType data_type, int fixed_point_position, QuantizationInfo quantization_info)
+    TensorType compute_target(const TensorShape &shape, DataType data_type, int fixed_point_position,
+                              QuantizationInfo quantization_info, float beta)
     {
         // Create tensors
         TensorType src = create_tensor<TensorType>(shape, data_type, 1, fixed_point_position, quantization_info);
@@ -86,7 +87,7 @@
 
         // Create and configure function
         FunctionType smx_layer;
-        smx_layer.configure(&src, &dst);
+        smx_layer.configure(&src, &dst, beta);
 
         ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -107,7 +108,8 @@
         return dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, int fixed_point_position, QuantizationInfo quantization_info)
+    SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, int fixed_point_position,
+                                      QuantizationInfo quantization_info, float beta)
     {
         // Create reference
         SimpleTensor<T> src{ shape, data_type, 1, fixed_point_position, quantization_info };
@@ -115,7 +117,7 @@
         // Fill reference
         fill(src);
 
-        return reference::softmax_layer<T>(src);
+        return reference::softmax_layer<T>(src, beta);
     }
 
     TensorType       _target{};
@@ -129,9 +131,13 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type)
+    void setup(TensorShape shape, DataType data_type, float beta)
     {
-        SoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, 0, QuantizationInfo());
+        SoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
+                                                                                          data_type,
+                                                                                          0,
+                                                                                          QuantizationInfo(),
+                                                                                          beta);
     }
 };
 
@@ -142,7 +148,11 @@
     template <typename...>
     void setup(TensorShape shape, DataType data_type, int fixed_point_position)
     {
-        SoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, fixed_point_position, QuantizationInfo());
+        SoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
+                                                                                          data_type,
+                                                                                          fixed_point_position,
+                                                                                          QuantizationInfo(),
+                                                                                          1.0f); // fixed-point reference ignores beta, so only the default 1.0f is tested
     }
 };
 
@@ -151,9 +161,13 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info)
+    void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, float beta)
     {
-        SoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, 0, quantization_info);
+        SoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
+                                                                                          data_type,
+                                                                                          0,
+                                                                                          quantization_info,
+                                                                                          beta);
     }
 };
 } // namespace validation
diff --git a/tests/validation/reference/SoftmaxLayer.cpp b/tests/validation/reference/SoftmaxLayer.cpp
index 8e8cc1b..90b9b1f 100644
--- a/tests/validation/reference/SoftmaxLayer.cpp
+++ b/tests/validation/reference/SoftmaxLayer.cpp
@@ -35,7 +35,7 @@
 namespace reference
 {
 template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
-SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src)
+SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta)
 {
     // Create reference
     SimpleTensor<T> dst{ src.shape(), src.data_type(), 1, src.fixed_point_position() };
@@ -54,9 +54,9 @@
 
         // Regularize
         T sum(0.f);
-        std::transform(src_row_ptr, src_row_ptr + cols, dst_row_ptr, [&sum, max](T val)
+        std::transform(src_row_ptr, src_row_ptr + cols, dst_row_ptr, [&sum, max, beta](T val)
         {
-            const T res(std::exp(val - max));
+            const T res(std::exp((val - max) * beta)); // scale the shifted input by beta before exponentiating
             sum += res;
             return res;
         });
@@ -72,8 +72,10 @@
 }
 
 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type>
-SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src)
+SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta)
 {
+    ARM_COMPUTE_UNUSED(beta); // beta is not applied in the fixed-point reference
+
     using namespace fixed_point_arithmetic;
 
     // Create reference
@@ -113,21 +115,21 @@
 }
 
 template <>
-SimpleTensor<uint8_t> softmax_layer<uint8_t>(const SimpleTensor<uint8_t> &src)
+SimpleTensor<uint8_t> softmax_layer<uint8_t>(const SimpleTensor<uint8_t> &src, float beta)
 {
     // Note: Output quantization info should always have scale = 1/256 and offset = 0
     const QuantizationInfo output_quantization_info = QuantizationInfo(1.f / 256, 0);
 
     SimpleTensor<float>   src_tmp = convert_from_asymmetric(src);
-    SimpleTensor<float>   dst_tmp = softmax_layer<float>(src_tmp);
+    SimpleTensor<float>   dst_tmp = softmax_layer<float>(src_tmp, beta);
     SimpleTensor<uint8_t> dst     = convert_to_asymmetric(dst_tmp, output_quantization_info);
     return dst;
 }
 
-template SimpleTensor<float> softmax_layer(const SimpleTensor<float> &src);
-template SimpleTensor<half> softmax_layer(const SimpleTensor<half> &src);
-template SimpleTensor<qint8_t> softmax_layer(const SimpleTensor<qint8_t> &src);
-template SimpleTensor<qint16_t> softmax_layer(const SimpleTensor<qint16_t> &src);
+template SimpleTensor<float> softmax_layer(const SimpleTensor<float> &src, float beta);
+template SimpleTensor<half> softmax_layer(const SimpleTensor<half> &src, float beta);
+template SimpleTensor<qint8_t> softmax_layer(const SimpleTensor<qint8_t> &src, float beta);
+template SimpleTensor<qint16_t> softmax_layer(const SimpleTensor<qint16_t> &src, float beta);
 } // namespace reference
 } // namespace validation
 } // namespace test
diff --git a/tests/validation/reference/SoftmaxLayer.h b/tests/validation/reference/SoftmaxLayer.h
index ab79bc4..a6d4c3b 100644
--- a/tests/validation/reference/SoftmaxLayer.h
+++ b/tests/validation/reference/SoftmaxLayer.h
@@ -36,10 +36,10 @@
 namespace reference
 {
 template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
-SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src);
+SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta = 1.0f); // beta scales (x - max) before exp(); defaults to plain softmax so one-argument callers keep working
 
 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
-SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src);
+SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta = 1.0f); // beta is ignored for fixed-point types; the uint8_t specialization forwards it to the float path
 } // namespace reference
 } // namespace validation
 } // namespace test