COMPMID-617: Add validate support for NEON PixelWiseMultiplication

Change-Id: Ie81a4d667146315fed7668cf2ca752d3bf49b0ab
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111013
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index a2f3cff..d765966 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -54,6 +54,68 @@
 const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
 const float32x4_t positive_round_f32q    = vdupq_n_f32(0.5f);
 
+inline Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
+{
+    ARM_COMPUTE_UNUSED(overflow_policy);
+    ARM_COMPUTE_UNUSED(rounding_policy);
+
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
+                                    "Output can only be U8 if both inputs are U8");
+
+    if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type()))
+    {
+        // Check that all data types are the same and all fixed-point positions are the same
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
+        // Check if scale is representable in fixed-point with the provided settings
+        ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(scale, input1);
+    }
+
+    if(std::abs(scale - scale255_constant) < 0.00001f)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
+    }
+    else
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);
+
+        int         exponent            = 0;
+        const float normalized_mantissa = std::frexp(scale, &exponent);
+
+        // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
+        // frexp returns 0.5 as mantissa which means that the exponent will be in the range of -14 <= e <= 1
+        // Moreover, (e - 1) will be non-positive as we deal with 1/2^n, i.e. e - 1 == -n
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (Should be 1/(2^n) or 1/255");
+    }
+
+    return Error{};
+}
+
+inline std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+{
+    constexpr unsigned int num_elems_processed_per_iteration = 16;
+
+    // Configure kernel window
+    Window                 win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration));
+    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+
+    bool window_changed = update_window_and_padding(win,
+                                                    AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration),
+                                                    AccessWindowHorizontal(input2, 0, num_elems_processed_per_iteration),
+                                                    output_access);
+
+    ValidRegion valid_region = intersect_valid_regions(input1->valid_region(),
+                                                       input2->valid_region());
+
+    output_access.set_valid_region(win, valid_region);
+
+    Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{};
+    return std::make_pair(err, win);
+}
+
 /* Scales a given vector by 1/255.
  *
  * @note This does not work for all cases. e.g. for float of 0.49999999999999994 and large floats.
@@ -443,6 +505,7 @@
 
 void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
 {
+    ARM_COMPUTE_UNUSED(rounding_policy);
     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
 
     // Auto initialize output if not initialized
@@ -468,19 +531,7 @@
         }
     }
 
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
-                             "Output can only be U8 if both inputs are U8");
-    if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type()))
-    {
-        // Check that all data types are the same and all fixed-point positions are the same
-        ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
-        // Check if scale is representable in fixed-point with the provided settings
-        ARM_COMPUTE_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(scale, input1);
-    }
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), scale, overflow_policy, rounding_policy));
 
     _input1         = input1;
     _input2         = input2;
@@ -495,32 +546,17 @@
     // Check and validate scaling factor
     if(std::abs(scale - scale255_constant) < 0.00001f)
     {
-        ARM_COMPUTE_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
-        ARM_COMPUTE_UNUSED(rounding_policy);
-
         is_scale_255 = true;
     }
     else
     {
-        ARM_COMPUTE_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);
-        ARM_COMPUTE_UNUSED(rounding_policy);
+        int exponent = 0;
 
-        int         exponent            = 0;
-        const float normalized_mantissa = std::frexp(scale, &exponent);
+        std::frexp(scale, &exponent);
 
-        // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
-        // frexp returns 0.5 as mantissa which means that the exponent will be in the range of -1 <= e <= 14
-        // Moreover, it will be negative as we deal with 1/2^n
-        if((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1))
-        {
-            // Store the positive exponent. We know that we compute 1/2^n
-            // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5
-            _scale_exponent = std::abs(exponent - 1);
-        }
-        else
-        {
-            ARM_COMPUTE_ERROR("Scale value not supported (Should be 1/(2^n) or 1/255");
-        }
+        // Store the positive exponent. We know that we compute 1/2^n
+        // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5
+        _scale_exponent = std::abs(exponent - 1);
     }
 
     const DataType dt_input1 = input1->info()->data_type();
@@ -620,23 +656,19 @@
         ARM_COMPUTE_ERROR("You called with the wrong img formats");
     }
 
-    constexpr unsigned int num_elems_processed_per_iteration = 16;
-
     // Configure kernel window
-    Window                 win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
-    AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
+    auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+    INEKernel::configure(win_config.second);
+}
 
-    update_window_and_padding(win,
-                              AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration),
-                              AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration),
-                              output_access);
+Error NEPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy,
+                                                RoundingPolicy rounding_policy)
+{
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first);
 
-    ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
-                                                       input2->info()->valid_region());
-
-    output_access.set_valid_region(win, valid_region);
-
-    INEKernel::configure(win);
+    return Error{};
 }
 
 void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo &info)
diff --git a/src/runtime/NEON/functions/NEPixelWiseMultiplication.cpp b/src/runtime/NEON/functions/NEPixelWiseMultiplication.cpp
index 2e2ea11..9cd9d59 100644
--- a/src/runtime/NEON/functions/NEPixelWiseMultiplication.cpp
+++ b/src/runtime/NEON/functions/NEPixelWiseMultiplication.cpp
@@ -36,3 +36,7 @@
     k->configure(input1, input2, output, scale, overflow_policy, rounding_policy);
     _kernel = std::move(k);
 }
+Error NEPixelWiseMultiplication::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
+{
+    return NEPixelWiseMultiplicationKernel::validate(input1, input2, output, scale, overflow_policy, rounding_policy);
+}