COMPMID-1566: Add broadcast to NEArithmeticSubtraction

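This change lets either input broadcast along any dimension whose size is 1,
with the output taking the broadcast shape. A minimal usage sketch follows; it
is illustrative only (not part of the patch) and assumes the existing
NEArithmeticSubtraction run-time function plus hypothetical [4,4,2] and
[1,1,2] F32 shapes:

    // Illustrative sketch only: exercises the broadcast path added here,
    // assuming the NEArithmeticSubtraction run-time interface of this release.
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    void broadcast_sub_example() // hypothetical helper, not in the library
    {
        Tensor input1{}, input2{}, output{};
        // input2 has size 1 in X and Y, so it is broadcast against input1.
        input1.allocator()->init(TensorInfo(TensorShape(4U, 4U, 2U), 1, DataType::F32));
        input2.allocator()->init(TensorInfo(TensorShape(1U, 1U, 2U), 1, DataType::F32));
        output.allocator()->init(TensorInfo(TensorShape(4U, 4U, 2U), 1, DataType::F32));

        NEArithmeticSubtraction sub;
        sub.configure(&input1, &input2, &output, ConvertPolicy::SATURATE);

        input1.allocator()->allocate();
        input2.allocator()->allocate();
        output.allocator()->allocate();

        // Fill input1 and input2, then:
        sub.run();
    }
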
Change-Id: I05d21f9a92013ecfd1128d12cf1561cfd6e5c5e9
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/147983
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
index 3c76548..ff8fb84 100644
--- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
@@ -46,10 +46,12 @@
 
 namespace
 {
+constexpr unsigned int num_elems_processed_per_iteration = 16;
+
 void sub_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -64,8 +66,8 @@
 
 void sub_saturate_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -80,8 +82,8 @@
 
 void sub_wrap_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -104,8 +106,8 @@
 
 void sub_saturate_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -144,8 +146,8 @@
 void sub_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -167,8 +169,8 @@
 
 void sub_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -192,8 +194,8 @@
 }
 void sub_wrap_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -213,8 +215,8 @@
 
 void sub_saturate_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -234,8 +236,8 @@
 
 void sub_wrap_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -255,8 +257,8 @@
 
 void sub_saturate_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -276,8 +278,8 @@
 
 void sub_wrap_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -298,8 +300,8 @@
 
 void sub_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
 {
-    Iterator input1(in1, window);
-    Iterator input2(in2, window);
+    Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+    Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
     Iterator output(out, window);
 
     execute_window_loop(window, [&](const Coordinates & id)
@@ -318,43 +320,71 @@
     input1, input2, output);
 }
 
-inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output, ConvertPolicy policy)
 {
     ARM_COMPUTE_UNUSED(policy);
-    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
 
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-        !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8)
-        && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
-        && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
-        && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
-        && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
-        && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32)
-        && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16),
-        "You called subtract with the wrong image formats");
+    const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
 
+    // Validate in case of configured output
+    if(output.total_size() > 0)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+            !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8)
+            && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
+            && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
+            && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
+            && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
+            && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
+            && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16),
+            "You called subtract with the wrong image formats");
+
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
+                                        "Wrong shape for output");
+    }
     return Status{};
 }
 
-inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
 {
-    constexpr unsigned int num_elems_processed_per_iteration = 16;
+    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2);
+    const TensorShape &out_shape    = broadcast_pair.first;
+    const ValidRegion &valid_region = broadcast_pair.second;
 
-    // Configure kernel window
-    Window                 win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration));
-    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+    // Auto initialize output if not initialized
+    {
+        set_shape_if_empty(output, out_shape);
 
-    bool window_changed = update_window_and_padding(win,
-                                                    AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration),
-                                                    AccessWindowHorizontal(input2, 0, num_elems_processed_per_iteration),
-                                                    output_access);
+        if(input1.data_type() == DataType::S16 || input2.data_type() == DataType::S16)
+        {
+            set_format_if_unknown(output, Format::S16);
+        }
+        else if(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16)
+        {
+            set_format_if_unknown(output, Format::F16);
+        }
+        else if(input1.data_type() == DataType::F32 || input2.data_type() == DataType::F32)
+        {
+            set_format_if_unknown(output, Format::F32);
+        }
+    }
 
-    ValidRegion valid_region = intersect_valid_regions(input1->valid_region(),
-                                                       input2->valid_region());
+    Window win        = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration));
+    Window win_input1 = win.broadcast_if_dimension_le_one(input1);
+    Window win_input2 = win.broadcast_if_dimension_le_one(input2);
+
+    AccessWindowHorizontal input1_access(&input1, 0, num_elems_processed_per_iteration);
+    AccessWindowHorizontal input2_access(&input2, 0, num_elems_processed_per_iteration);
+    AccessWindowHorizontal output_access(&output, 0, num_elems_processed_per_iteration);
+
+    bool window_changed = update_window_and_padding(win_input1, input1_access)
+                          || update_window_and_padding(win_input2, input2_access)
+                          || update_window_and_padding(win, output_access);
 
     output_access.set_valid_region(win, valid_region);
 
@@ -371,26 +401,11 @@
 void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy));
 
-    // Auto initialize output if not initialized
-    {
-        set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
-
-        if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
-        {
-            set_format_if_unknown(*output->info(), Format::S16);
-        }
-        else if(input1->info()->data_type() == DataType::F16 || input2->info()->data_type() == DataType::F16)
-        {
-            set_format_if_unknown(*output->info(), Format::F16);
-        }
-        else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
-        {
-            set_format_if_unknown(*output->info(), Format::F32);
-        }
-    }
-
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), policy));
+    // Configure kernel window
+    auto win_config = validate_and_configure_window(*input1->info(), *input2->info(), *output->info());
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
 
     static std::map<std::string, NEArithmeticSubtractionKernel::SubFunction *> map_function =
     {
@@ -427,16 +442,15 @@
         _func = it->second;
     }
 
-    // Configure kernel window
-    auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
-    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     INEKernel::configure(win_config.second);
 }
 
 Status NEArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first);
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
+
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output, policy));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(*input1->clone(), *input2->clone(), *output->clone()).first);
 
     return Status{};
 }
@@ -450,3 +464,10 @@
 
     (*_func)(_input1, _input2, _output, window);
 }
+
+BorderSize NEArithmeticSubtractionKernel::border_size() const
+{
+    const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
+    const unsigned int border        = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
+    return BorderSize(0, border, 0, 0);
+}
\ No newline at end of file