COMPMID-617: Add validate support for NEON ArithmeticLayer

Change-Id: I8b58359487194f4cbf7452df4aea92523b5745bf
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111351
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
index 8e55994..6452393 100644
--- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
@@ -355,6 +355,57 @@
     },
     input1, input2, output);
 }
+
+inline Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+{ // Runs shape, data type/channel, fixed-point and type-combination checks on the three tensor infos; reaches the final Error{} only if none of the check macros returned first (macro bodies are not visible here — presumably each returns an Error on failure).
+    ARM_COMPUTE_UNUSED(policy); // policy is not inspected by any check in this function
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+
+    if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type()))
+    {
+        // Check that all data types are the same and all fixed-point positions are the same
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
+    }
+
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG( // condition is true when none of the nine (input1, input2, output) data type triples listed below matches
+        !(input1->data_type() == DataType::QS8 && input2->data_type() == DataType::QS8 && output->data_type() == DataType::QS8)
+        && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8)
+        && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
+        && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
+        && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
+        && !(input1->data_type() == DataType::QS16 && input2->data_type() == DataType::QS16 && output->data_type() == DataType::QS16)
+        && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
+        && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32)
+        && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16),
+        "You called addition with the wrong image formats");
+
+    return Error{}; // default-constructed Error — presumably the "no error" value; confirm against Error's definition
+}
+
+inline std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+{ // Computes the kernel window for the three tensor infos and returns it paired with an Error that is set when update_window_and_padding() reports a window change.
+    constexpr unsigned int num_elems_processed_per_iteration = 16; // used both as the window step and as the width of each horizontal access below
+
+    // Configure kernel window
+    Window                 win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration));
+    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+
+    bool window_changed = update_window_and_padding(win, // NOTE(review): presumably true when the window had to be adjusted because padding could not be extended — confirm against the helper
+                                                    AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration),
+                                                    AccessWindowHorizontal(input2, 0, num_elems_processed_per_iteration),
+                                                    output_access);
+
+    ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), // the region handed to the output access is the intersection of both inputs' valid regions
+                                                       input2->valid_region());
+
+    output_access.set_valid_region(win, valid_region);
+
+    Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; // a changed window is reported as a runtime error; otherwise a default Error is returned
+    return std::make_pair(err, win);
+}
 } // namespace
 
 NEArithmeticAdditionKernel::NEArithmeticAdditionKernel()
@@ -384,17 +435,7 @@
         }
     }
 
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
-                             "Output can only be U8 if both inputs are U8");
-    if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type()))
-    {
-        // Check that all data types are the same and all fixed-point positions are the same
-        ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
-    }
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), policy));
 
     static std::map<std::string, AddFunction *> map_function =
     {
@@ -416,7 +457,6 @@
         { "add_saturate_F32_F32_F32", &add_F32_F32_F32 },
         { "add_wrap_F16_F16_F16", &add_F16_F16_F16 },
         { "add_saturate_F16_F16_F16", &add_F16_F16_F16 },
-
     };
 
     _input1 = input1;
@@ -435,28 +475,19 @@
     {
         _func = it->second;
     }
-    else
-    {
-        ARM_COMPUTE_ERROR("You called arithmetic addition with the wrong tensor data type");
-    }
-
-    constexpr unsigned int num_elems_processed_per_iteration = 16;
 
     // Configure kernel window
-    Window                 win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
-    AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
+    auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+    INEKernel::configure(win_config.second);
+}
 
-    update_window_and_padding(win,
-                              AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration),
-                              AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration),
-                              output_access);
+Error NEArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+{ // Chains validate_arguments() and validate_and_configure_window() through ARM_COMPUTE_RETURN_ON_ERROR (presumably an early return on error) and falls through to Error{} otherwise.
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); // clones are passed because the helper takes non-const infos while the parameters here are const; only the Error half of the returned pair is used
 
-    ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
-                                                       input2->info()->valid_region());
-
-    output_access.set_valid_region(win, valid_region);
-
-    INEKernel::configure(win);
+    return Error{};
 }
 
 void NEArithmeticAdditionKernel::run(const Window &window, const ThreadInfo &info)