COMPMID-3442: Add support for negative axis in NESoftmaxLayer and reference code
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I285cc3b74ac0a45f0ad5830baed5237cea568f15
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3147
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index c429782..8af3847 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -97,9 +97,9 @@
framework::dataset::make("axis", { 1,
1,
1,
+ -1,
1,
- 1,
- 0,
+ -3,
})),
framework::dataset::make("Expected", { false, false, false, true, true, false })),
input_info, output_info, beta, axis, expected)
@@ -188,7 +188,7 @@
framework::dataset::make("DataType", DataType::QASYMM8)),
combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
framework::dataset::make("Beta", { 1.0f, 2.f }))),
- framework::dataset::make("Axis", { 1, 2, 3 })))
+ framework::dataset::make("Axis", { -1, 2, 3 })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -209,7 +209,7 @@
framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
framework::dataset::make("Beta", { 1.0f, 2.f }))),
- framework::dataset::make("Axis", { 1 })))
+ framework::dataset::make("Axis", { -1, 1 })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
@@ -218,7 +218,7 @@
framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
framework::dataset::make("Beta", { 1.0f, 2.f }))),
- framework::dataset::make("Axis", { 1, 2, 3 })))
+ framework::dataset::make("Axis", { -2, 2, 3 })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8_signed);