Add support for S64 output in NEArgMinMaxLayer

* NEArgMinMaxLayer uses NEReductionOperation, which computes its result in S32

* When an S64 output is requested, NECast converts the result from S32 to S64 (see the usage sketch below)

* Resolves MLCE-1089
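
A minimal usage sketch of the new path (the shapes, axis and reduction op
below are illustrative and not part of this patch; assumes an aarch64 build):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEArgMinMaxLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        Tensor src{};
        Tensor dst{};
        src.allocator()->init(TensorInfo(TensorShape(16U, 4U), 1, DataType::S32));
        // Requesting an S64 output selects the new reduction-then-cast path.
        dst.allocator()->init(TensorInfo(TensorShape(4U), 1, DataType::S64));

        NEArgMinMaxLayer argminmax(nullptr);
        argminmax.configure(&src, /* axis */ 0, &dst, ReductionOperation::ARG_IDX_MAX);

        src.allocator()->allocate();
        dst.allocator()->allocate();
        argminmax.run();
        return 0;
    }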

Change-Id: I6fded869b6076d7af1b9b3e70eb384f4ee82fd8a
Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10054
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/runtime/NEON/functions/NEArgMinMaxLayer.h b/arm_compute/runtime/NEON/functions/NEArgMinMaxLayer.h
index 4392de7..3bb50a0 100644
--- a/arm_compute/runtime/NEON/functions/NEArgMinMaxLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEArgMinMaxLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,8 +24,6 @@
 #ifndef ARM_COMPUTE_NEARGMINMAXLAYER_H
 #define ARM_COMPUTE_NEARGMINMAXLAYER_H
 
-#include "arm_compute/runtime/NEON/functions/NEReductionOperation.h"
-
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 #include "arm_compute/runtime/NEON/INESimpleFunction.h"
@@ -33,7 +31,6 @@
 namespace arm_compute
 {
 class ITensor;
-
 /** Function to calculate the index of the minimum or maximum values in a
  *  tensor based on an axis.
  *
@@ -68,13 +65,13 @@
      * - All
      *
      * Valid data type configurations:
-     * |src            |dst        |
-     * |:--------------|:----------|
-     * |QASYMM8        |U32, S32   |
-     * |QASYMM8_SIGNED |U32, S32   |
-     * |S32            |U32, S32   |
-     * |F16            |U32, S32   |
-     * |F32            |U32, S32   |
+     * |src            |dst           |
+     * |:--------------|:-------------|
+     * |QASYMM8        |U32, S32      |
+     * |QASYMM8_SIGNED |U32, S32      |
+     * |S32            |U32, S32, S64 |
+     * |F16            |U32, S32      |
+     * |F32            |U32, S32      |
      *
      * @param[in]  input  Input source tensor. Data types supported: QASYMM8_SIGNED/QASYMM8/S32/F16/F32.
      * @param[in]  axis   Axis to find max/min index.
@@ -86,7 +83,7 @@
      *
      * @param[in] input  Input source tensor info. Data types supported: QASYMM8_SIGNED/QASYMM8/S32/F16/F32.
      * @param[in] axis   Axis to find max/min index.
-     * @param[in] output Output source tensor info. Data types supported: U32/S32.
+     * @param[in] output Output tensor info. Data types supported: U32/S32/S64.
      * @param[in] op     Operation to perform: min or max
      *
      * @return a status
@@ -97,7 +94,8 @@
     void run() override;
 
 private:
-    std::unique_ptr<NEReductionOperation> _reduction_function;
+    struct Impl;
+    std::unique_ptr<Impl> _impl;
 };
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_NEARGMINMAXLAYER_H */
diff --git a/docs/user_guide/operator_list.dox b/docs/user_guide/operator_list.dox
index 66b8988..e0b4541 100644
--- a/docs/user_guide/operator_list.dox
+++ b/docs/user_guide/operator_list.dox
@@ -126,7 +126,7 @@
     <tr><th>src<th>dst
     <tr><td>QASYMM8<td>U32, S32
     <tr><td>QASYMM8_SIGNED<td>U32, S32
-    <tr><td>S32<td>U32, S32
+    <tr><td>S32<td>U32, S32, S64
     <tr><td>F16<td>U32, S32
     <tr><td>F32<td>U32, S32
     </table>
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index edc0c3b..8d25223 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -40,6 +40,11 @@
 @note Starting from release 22.05, 'master' branch is no longer being used, it has been replaced by 'main'. Please update your clone jobs accordingly.
 
 @section S2_2_changelog Changelog
+
+v23.11 Public major release
+   - Add support for input data type U64/S64 in CLCast and NECast.
+   - Add support for output data type S64 in NEArgMinMaxLayer and CLArgMinMaxLayer.
+
 v23.08 Public major release
  - Deprecate the legacy 'libarm_compute_core' library. This library is an artifact of Compute Library's legacy library architecture and no longer serves any purpose.
  Users must no longer link their applications to this library and instead link only to the main `libarm_compute` library for core functionality.
diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp
index 641dea4..d478328 100644
--- a/src/cpu/kernels/CpuCastKernel.cpp
+++ b/src/cpu/kernels/CpuCastKernel.cpp
@@ -103,15 +103,20 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
                                                          DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
                                                          DataType::F32, DataType::S32, DataType::S64, DataType::U64);
-#else // __aarch64__
+
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
+                                                         DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+                                                         DataType::U32, DataType::S32, DataType::F32, DataType::S64);
+
+#else // __aarch64__
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
                                                          DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
                                                          DataType::F32, DataType::S32);
-#endif // __aarch64__
 
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
                                                          DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
                                                          DataType::U32, DataType::S32, DataType::F32);
+#endif // __aarch64__
 
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::QASYMM8_SIGNED && (dst->data_type() != DataType::S16 && dst->data_type() != DataType::S32
                                                                                      && dst->data_type() != DataType::F16 && dst->data_type() != DataType::F32),
@@ -146,13 +151,15 @@
 
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::S32 && (dst->data_type() != DataType::QASYMM8_SIGNED && dst->data_type() != DataType::QASYMM8
                                                                           && dst->data_type() != DataType::F16
-                                                                          && dst->data_type() != DataType::F32 && dst->data_type() != DataType::U8),
-                                    "Only data_types supported [in] S32 ->  [out] QASYMM8, F16, F32, U8");
+                                                                          && dst->data_type() != DataType::F32
+                                                                          && dst->data_type() != DataType::U8
+                                                                          && dst->data_type() != DataType::S64),
+                                    "Only data_types supported [in] S32 ->  [out] QASYMM8, F16, F32, U8, S64");
 #ifdef __aarch64__
-     ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::S64 && dst->data_type() != DataType::F32,
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::S64 && dst->data_type() != DataType::F32,
                                     "Only data_types supported [in] S64 ->  [out] F32");
 
-     ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::U64 && dst->data_type() != DataType::F32,
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::U64 && dst->data_type() != DataType::F32,
                                     "Only data_types supported [in] U64 ->  [out] F32");
 #endif // __aarch64__
 
@@ -199,6 +206,28 @@
 }
 
 template <>
+inline void internal_neon_convert<int32_t, int64_t>(const int32_t *src_ptr, int64_t *dst_ptr)
+{
+    const int32x4x4_t texels =
+    {
+        {
+            vld1q_s32(src_ptr),
+            vld1q_s32(src_ptr + 4),
+            vld1q_s32(src_ptr + 8),
+            vld1q_s32(src_ptr + 12)
+        }
+    };
+    vst1q_s64(dst_ptr, vmovl_s32(vget_low_s32(texels.val[0])));
+    vst1q_s64(dst_ptr + 2, vmovl_s32(vget_high_s32(texels.val[0])));
+    vst1q_s64(dst_ptr + 4, vmovl_s32(vget_low_s32(texels.val[1])));
+    vst1q_s64(dst_ptr + 6, vmovl_s32(vget_high_s32(texels.val[1])));
+    vst1q_s64(dst_ptr + 8, vmovl_s32(vget_low_s32(texels.val[2])));
+    vst1q_s64(dst_ptr + 10, vmovl_s32(vget_high_s32(texels.val[2])));
+    vst1q_s64(dst_ptr + 12, vmovl_s32(vget_low_s32(texels.val[3])));
+    vst1q_s64(dst_ptr + 14, vmovl_s32(vget_high_s32(texels.val[3])));
+}
+
+template <>
 inline void internal_neon_convert<int64_t, float>(const int64_t *src_ptr, float *dst_ptr)
 {
     const float64x2x4_t texels0 =
@@ -1062,6 +1091,13 @@
         case DataType::S32:
             switch(_dst->info()->data_type())
             {
+#ifdef __aarch64__
+                case DataType::S64:
+                {
+                    convert64<int32_t, int64_t>(src, dst, win, window_start_x, window_end_x, window_step_x);
+                    break;
+                }
+#endif // __aarch64__
                 case DataType::F16:
                 {
                     /* Down-conversion S32 -> F16 */
diff --git a/src/cpu/kernels/CpuCastKernel.h b/src/cpu/kernels/CpuCastKernel.h
index 7623736..d8e61e6 100644
--- a/src/cpu/kernels/CpuCastKernel.h
+++ b/src/cpu/kernels/CpuCastKernel.h
@@ -61,9 +61,11 @@
      *   - F32            -> QASYMM8_SIGNED, QASYMM8, BFLOAT16, F16, S32, U8
      *
      * @param[in]  src    The src tensor to convert. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/S32/S64/BFLOAT16/F16/F32.
-     * @param[out] dst    The dst tensor. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/BFLOAT16/F16/F32.
+     * @param[out] dst    The dst tensor. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/S64/BFLOAT16/F16/F32.
      * @param[in]  policy Conversion policy.
      *
+     * @note S64 is only supported on aarch64
+     *
      * @deprecated Support for BFLOAT16 will be removed in 23.05 release
      */
     void configure(const ITensorInfo *src, ITensorInfo *dst, ConvertPolicy policy);
diff --git a/src/runtime/NEON/functions/NEArgMinMaxLayer.cpp b/src/runtime/NEON/functions/NEArgMinMaxLayer.cpp
index 3876ae6..3ac127b 100644
--- a/src/runtime/NEON/functions/NEArgMinMaxLayer.cpp
+++ b/src/runtime/NEON/functions/NEArgMinMaxLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,22 +29,49 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/NEON/functions/NECast.h"
+#include "arm_compute/runtime/NEON/functions/NEReductionOperation.h"
+#include "arm_compute/runtime/Tensor.h"
 #include "src/common/utils/Log.h"
 #include "src/core/NEON/kernels/NEReductionOperationKernel.h"
 
 namespace arm_compute
 {
+struct NEArgMinMaxLayer::Impl
+{
+    MemoryGroup                           memory_group{};
+    std::shared_ptr<IMemoryManager>       memory_manager{};
+    std::unique_ptr<NEReductionOperation> reduction_function{};
+    std::unique_ptr<NECast>               cast_function{};
+    std::unique_ptr<Tensor>               tmp_reduction_result{};
+};
+
 NEArgMinMaxLayer::~NEArgMinMaxLayer() = default;
 
 NEArgMinMaxLayer::NEArgMinMaxLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _reduction_function(std::make_unique<NEReductionOperation>())
+    : _impl(std::make_unique<Impl>())
 {
-    ARM_COMPUTE_UNUSED(memory_manager);
+    _impl->memory_manager = std::move(memory_manager);
 }
+
 void NEArgMinMaxLayer::configure(ITensor *input, int axis, ITensor *output, const ReductionOperation &op)
 {
     ARM_COMPUTE_LOG_PARAMS(input, axis, output, op);
-    _reduction_function->configure(input, output, axis, op, false);
+    _impl->reduction_function = std::make_unique<NEReductionOperation>();
+    if(output->info() && (output->info()->data_type() == DataType::S64 || output->info()->data_type() == DataType::U64))
+    {
+        _impl->memory_group         = MemoryGroup(std::move(_impl->memory_manager));
+        _impl->cast_function        = std::make_unique<NECast>();
+        _impl->tmp_reduction_result = std::make_unique<Tensor>();
+        _impl->reduction_function->configure(input, _impl->tmp_reduction_result.get(), axis, op, false);
+        _impl->cast_function->configure(_impl->tmp_reduction_result.get(), output, ConvertPolicy::SATURATE);
+        _impl->memory_group.manage(_impl->tmp_reduction_result.get());
+        _impl->tmp_reduction_result->allocator()->allocate();
+    }
+    else
+    {
+        _impl->reduction_function->configure(input, output, axis, op, false);
+    }
 }
 
 Status NEArgMinMaxLayer::validate(const ITensorInfo *input, int axis, const ITensorInfo *output, const ReductionOperation &op)
@@ -55,7 +82,12 @@
 
 void NEArgMinMaxLayer::run()
 {
-    _reduction_function->run();
+    MemoryGroupResourceScope scope_mg(_impl->memory_group);
+    _impl->reduction_function->run();
+    if(_impl->tmp_reduction_result != nullptr)
+    {
+        _impl->cast_function->run();
+    }
 }
 
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/tests/validation/NEON/ArgMinMax.cpp b/tests/validation/NEON/ArgMinMax.cpp
index 2e21a7d..c80c936 100644
--- a/tests/validation/NEON/ArgMinMax.cpp
+++ b/tests/validation/NEON/ArgMinMax.cpp
@@ -97,6 +97,8 @@
 using NEArgMinMaxValidationFixture_S32_S32 = NEArgMinMaxValidationFixture<int32_t, int32_t>;
 using NEArgMinMaxValidationFixture_F16_S32 = NEArgMinMaxValidationFixture<half, int32_t>;
 using NEArgMinMaxValidationFixture_F32_S32 = NEArgMinMaxValidationFixture<float, int32_t>;
+using NEArgMinMaxValidationFixture_F32_S64 = NEArgMinMaxValidationFixture<float, int64_t>;
+
 TEST_SUITE(S32)
 FIXTURE_DATA_TEST_CASE(RunSmallAxis0,
                        NEArgMinMaxValidationFixture_S32_S32,
@@ -182,6 +184,19 @@
     validate(Accessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmall_F32_S64,
+                       NEArgMinMaxValidationFixture_F32_S64,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(combine(ArgMinMaxSmallDataset(),
+                                                       framework::dataset::make("DataTypeIn", DataType::F32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S64)),
+                                       AxisDataset),
+                               OpsDataset))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        NEArgMinMaxValidationFixture_F32_S32,
                        framework::DatasetMode::NIGHTLY,