COMPMID-417: Auto configuration of the output tensor for Add/Sub/Mul NEON/CL kernels.

Change-Id: I3580de76bc53d342b53443d1077b1407d75a672a
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79570
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
diff --git a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
index aaa62d0..0cb0847 100644
--- a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
+++ b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
@@ -48,9 +48,32 @@
 
 void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, ConvertPolicy policy)
 {
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+
+    // Auto initialize output if not initialized
+    {
+        set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
+
+        if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+        {
+            set_format_if_unknown(*output->info(), Format::S16);
+        }
+        else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
+        {
+            set_format_if_unknown(*output->info(), Format::F32);
+        }
+        else if(input1->info()->data_type() == DataType::F16 && input2->info()->data_type() == DataType::F16)
+        {
+            set_format_if_unknown(*output->info(), Format::F16);
+        }
+    }
+
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
+                             "Output can only be U8 if both inputs are U8");
 
     _input1 = input1;
     _input2 = input2;
@@ -58,12 +81,6 @@
 
     const bool has_float_out = is_data_type_float(output->info()->data_type());
 
-    // Check for invalid combination
-    if(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8))
-    {
-        ARM_COMPUTE_ERROR("You called with the wrong data types.");
-    }
-
     // Set kernel build options
     std::set<std::string> build_opts;
     build_opts.emplace((policy == ConvertPolicy::WRAP || has_float_out) ? "-DWRAP" : "-DSATURATE");
diff --git a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
index 4c84727..69f9ff1 100644
--- a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
+++ b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
@@ -45,18 +45,28 @@
 
 void CLArithmeticSubtractionKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, ConvertPolicy policy)
 {
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+
+    // Auto initialize output if not initialized
+    {
+        set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
+
+        if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+        {
+            set_format_if_unknown(*output->info(), Format::S16);
+        }
+        else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
+        {
+            set_format_if_unknown(*output->info(), Format::F32);
+        }
+    }
+
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
-    // Check for invalid combination
-    if(output->info()->data_type() == DataType::U8)
-    {
-        ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8);
-        ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8);
-    }
-    else
-    {
-        ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
-        ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
-    }
+    ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
+                             "Output can only be U8 if both inputs are U8");
 
     _input1 = input1;
     _input2 = input2;
diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
index 84eb434..da417a9 100644
--- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
+++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
@@ -48,6 +48,23 @@
 void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale,
                                                 ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
 {
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+
+    // Auto initialize output if not initialized
+    {
+        set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
+
+        if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+        {
+            set_format_if_unknown(*output->info(), Format::S16);
+        }
+        else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
+        {
+            set_format_if_unknown(*output->info(), Format::F32);
+        }
+    }
+
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
index a4fdad8..60b8006 100644
--- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
@@ -294,15 +294,18 @@
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
 
-    set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
+    // Auto initialize output if not initialized
+    {
+        set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
 
-    if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
-    {
-        set_format_if_unknown(*output->info(), Format::S16);
-    }
-    else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
-    {
-        set_format_if_unknown(*output->info(), Format::F32);
+        if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+        {
+            set_format_if_unknown(*output->info(), Format::S16);
+        }
+        else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
+        {
+            set_format_if_unknown(*output->info(), Format::F32);
+        }
     }
 
     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
index d3e62b0..f51b6b9 100644
--- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
@@ -287,15 +287,18 @@
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
 
-    set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
+    // Auto initialize output if not initialized
+    {
+        set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
 
-    if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
-    {
-        set_format_if_unknown(*output->info(), Format::S16);
-    }
-    else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
-    {
-        set_format_if_unknown(*output->info(), Format::F32);
+        if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+        {
+            set_format_if_unknown(*output->info(), Format::S16);
+        }
+        else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
+        {
+            set_format_if_unknown(*output->info(), Format::F32);
+        }
     }
 
     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index aa8c7a1..7c95147 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -333,6 +333,28 @@
 
 void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
 {
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+
+    // Auto initialize output if not initialized
+    {
+        set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
+
+        if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+        {
+            set_format_if_unknown(*output->info(), Format::S16);
+        }
+        else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
+        {
+            set_format_if_unknown(*output->info(), Format::F32);
+        }
+        else if(input1->info()->data_type() == DataType::QS8 && input2->info()->data_type() == DataType::QS8)
+        {
+            set_data_type_if_unknown(*output->info(), DataType::QS8);
+            set_fixed_point_position_if_zero(*output->info(), input1->info()->fixed_point_position());
+        }
+    }
+
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);