Fix overflow in NEMeanStdDevNormalizationKernel

* Perform final sum in fp32 to avoid overflow

* Resolves ARMCL-1128

Change-Id: I89799baf81045697f7bc44017fcb6a440635caff
Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11311
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
index 6470f39..344b9df 100644
--- a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -66,26 +66,20 @@
                 sum_sq_vec       = vaddq_f32(sum_sq_vec, vmulq_f32(dh, dh));
             }
 
-            float16x4_t sum_carry_res = vpadd_f16(vget_high_f16(sum_vec), vget_low_f16(sum_vec));
-            sum_carry_res             = vpadd_f16(sum_carry_res, sum_carry_res);
-            sum_carry_res             = vpadd_f16(sum_carry_res, sum_carry_res);
-
-            float32x4_t sum_sq_carry_res = vpaddq_f32(sum_sq_vec, sum_sq_vec);
-            sum_sq_carry_res             = vpaddq_f32(sum_sq_carry_res, sum_sq_carry_res);
-
-            float16_t sum    = vget_lane_f16(sum_carry_res, 0);
-            float     sum_sq = vgetq_lane_f32(sum_sq_carry_res, 0);
+            float32x4_t sum_carry_res =
+                vpaddq_f32(vcvt_f32_f16(vget_high_f16(sum_vec)), vcvt_f32_f16(vget_low_f16(sum_vec)));
+            float sum    = vaddvq_f32(sum_carry_res);
+            float sum_sq = vaddvq_f32(sum_sq_vec);
 
             // Compute left-over elements
             for (; x < window_end_x; ++x)
             {
-                float16_t data = *(in_ptr + x);
-                sum += data;
-                float fdata = static_cast<float>(data);
+                const float fdata = static_cast<float>(*(in_ptr + x));
+                sum += fdata;
                 sum_sq += fdata * fdata;
             }
 
-            float16_t mean       = sum / input->info()->dimension(0);
+            float16_t mean       = static_cast<float16_t>(sum / input->info()->dimension(0));
             float     var        = (sum_sq / input->info()->dimension(0)) - (mean * mean);
             float16_t stddev_inv = static_cast<float16_t>(1.f / sqrt(var + epsilon));