COMPMID-3871: Create BatchNormalization SVE/SVE2
1. Decouple data type for NHWC
2. Add NHWC SVE support for BatchNormalization
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I0383b969b555b429d9acebb4efa17ecba9429ea7
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4755
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
diff --git a/tests/validation/NEON/BatchNormalizationLayer.cpp b/tests/validation/NEON/BatchNormalizationLayer.cpp
index 067c5bb..b24357f 100644
--- a/tests/validation/NEON/BatchNormalizationLayer.cpp
+++ b/tests/validation/NEON/BatchNormalizationLayer.cpp
@@ -51,8 +51,10 @@
RelativeTolerance<float> rel_tolerance_f32(0.05f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-constexpr AbsoluteTolerance<float> tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+RelativeTolerance<float> rel_tolerance_f16(0.05f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr AbsoluteTolerance<float> abs_tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
const auto act_infos = framework::dataset::make("ActivationInfo",
{
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
@@ -148,7 +150,7 @@
framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
- validate(Accessor(_target), _reference, tolerance_f16, 0);
+ validate(Accessor(_target), _reference, abs_tolerance_f16, 0);
}
FIXTURE_DATA_TEST_CASE(RandomLarge, NEBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::LargeRandomBatchNormalizationLayerDataset(),
@@ -159,7 +161,7 @@
framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
- validate(Accessor(_target), _reference, tolerance_f16, 0);
+ validate(Accessor(_target), _reference, abs_tolerance_f16, 0);
}
TEST_SUITE_END() // FP16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */