Add int32 and int64 ArgMax op support

This patch adds int32 and int64 ArgMax op support.

ArmNN already provides an ArgMax op, but the TfLite parser never uses
it, and the op does not support an int64 output type.

This patch therefore adds a new data type, Signed64, and an ArgMinMax
computation function for int64 outputs.

By default, the output tensor type of a TensorFlow Lite ArgMax op is
int64, so this patch dispatches to the proper function - the int64 or
int32 ArgMax implementation - according to the parsed output_type value.

With this patch, ArmNN supports both output types - int64 and int32 -
for the ArgMinMax op.
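
For reference, a minimal usage sketch (not part of the patch; it assumes
an armnn::INetwork* named "network" that is built and connected
elsewhere) of requesting an int64 ArgMax through the public API:

    armnn::ArgMinMaxDescriptor desc;
    desc.m_Function    = armnn::ArgMinMaxFunction::Max;
    desc.m_Axis        = -1;                        // reduce over the last axis
    desc.m_Output_Type = armnn::DataType::Signed64; // or armnn::DataType::Signed32
    armnn::IConnectableLayer* argMax = network->AddArgMinMaxLayer(desc, "ArgMax");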

Changelog v1:
- Check that the output data type of the ArgMinMax op is valid.
- Use a template function to support both int32 and int64 output types in the ArgMinMax function (see the sketch after this list).
- Keep Signed32 as the default data type of m_Output_Type.
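
The templated reduction can be sketched standalone; the following is an
illustrative, simplified version of the approach (NaiveArgMax and its
parameters are hypothetical names, not taken from the patch), using the
same flattened outer/axis/inner iteration as the reference workload:

    #include <cstddef>
    #include <cstdint>

    // Illustrative only: argmax over the middle dimension of a flattened
    // [outer, axisSize, inner] view, writing indices as OUT (int32_t or int64_t).
    template <typename OUT>
    void NaiveArgMax(const float* in, OUT* out,
                     std::size_t outer, std::size_t axisSize, std::size_t inner)
    {
        for (std::size_t o = 0; o < outer; ++o)
        {
            for (std::size_t i = 0; i < inner; ++i)
            {
                std::size_t best = 0;
                for (std::size_t a = 1; a < axisSize; ++a)
                {
                    if (in[(o * axisSize + a) * inner + i] >
                        in[(o * axisSize + best) * inner + i])
                    {
                        best = a;
                    }
                }
                out[o * inner + i] = static_cast<OUT>(best);
            }
        }
    }

    // Explicit instantiations, mirroring the ones the patch adds:
    template void NaiveArgMax<int32_t>(const float*, int32_t*, std::size_t, std::size_t, std::size_t);
    template void NaiveArgMax<int64_t>(const float*, int64_t*, std::size_t, std::size_t, std::size_t);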

Change-Id: I7a8e7e38dd9e5acc81464571d8b4d51378fc7f14
Signed-off-by: Inki Dae <inki.dae@samsung.com>
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 241b23d..2834336 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -53,17 +53,20 @@
     ArgMinMaxDescriptor()
         : m_Function(ArgMinMaxFunction::Min)
         , m_Axis(-1)
+        , m_Output_Type(armnn::DataType::Signed32)
     {}
 
     bool operator ==(const ArgMinMaxDescriptor &rhs) const
     {
-        return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis;
+        return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis && m_Output_Type == rhs.m_Output_Type;
     }
 
     /// Specify if the function is to find Min or Max.
     ArgMinMaxFunction m_Function;
     /// Axis to reduce across the input tensor.
     int m_Axis;
+    /// Output tensor data type: Signed32 or Signed64. Defaults to Signed32.
+    armnn::DataType m_Output_Type;
 };
 
 /// A ComparisonDescriptor for the ComparisonLayer
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 11d807c..4a01549 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -41,6 +41,7 @@
     QSymmS8 = 7,
     QAsymmS8 = 8,
     BFloat16 = 9,
+    Signed64 = 10,
 
     QuantisedAsymm8 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QAsymmU8 instead.") = QAsymmU8,
     QuantisedSymm16 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QSymmS16 instead.") = QSymmS16
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index a2b3c95..efc69de 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -120,6 +120,7 @@
         case DataType::Float16:               return 2U;
         case DataType::Float32:
         case DataType::Signed32:              return 4U;
+        case DataType::Signed64:              return 8U;
         case DataType::QAsymmU8:              return 1U;
         case DataType::QAsymmS8:              return 1U;
         case DataType::QSymmS8:               return 1U;
@@ -171,6 +172,7 @@
     {
         case DataType::Float16:               return "Float16";
         case DataType::Float32:               return "Float32";
+        case DataType::Signed64:              return "Signed64";
         case DataType::QAsymmU8:              return "QAsymmU8";
         case DataType::QAsymmS8:              return "QAsymmS8";
         case DataType::QSymmS8:               return "QSymmS8";
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 6143f4a..0aad048 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -345,7 +345,9 @@
         case tflite::TensorType_INT32:
             type = armnn::DataType::Signed32;
             break;
-
+        case tflite::TensorType_INT64:
+            type = armnn::DataType::Signed64;
+            break;
         default:
         {
             CheckLocation location = CHECK_LOCATION();
@@ -598,6 +600,7 @@
     m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV]          = &TfLiteParser::ParseTransposeConv;
     m_ParserFunctions[tflite::BuiltinOperator_UNPACK]                  = &TfLiteParser::ParseUnpack;
     m_ParserFunctions[tflite::BuiltinOperator_DIV]                     = &TfLiteParser::ParseDiv;
+    m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX]                 = &TfLiteParser::ParseArgMax;
     // register supported custom operators
     m_CustomParserFunctions["TFLite_Detection_PostProcess"]      = &TfLiteParser::ParseDetectionPostProcess;
 }
@@ -2847,6 +2850,47 @@
     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
 }
 
+void TfLiteParser::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
+{
+    CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+    const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+    const auto* options = operatorPtr->builtin_options.AsArgMaxOptions();
+    auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+    CHECK_VALID_SIZE(inputs.size(), 2);
+
+    auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+    CHECK_VALID_SIZE(outputs.size(), 1);
+
+    auto layerName = boost::str(boost::format("ArgMax:%1%:%2%") % subgraphIndex % operatorIndex);
+
+    armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+    armnn::TensorInfo axisTensorInfo  = ToTensorInfo(inputs[1]);
+
+    // Get the constant axis value from the model and set it on the descriptor.
+    BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+
+    ArgMinMaxDescriptor desc;
+    desc.m_Axis = reinterpret_cast<const int32_t*>(axisBufferPtr->data.data())[0];
+    // TfLite's ArgMax output_type defaults to int64; only an explicit int32 maps to Signed32.
+    desc.m_Output_Type = options->output_type == tflite::TensorType_INT32
+                             ? armnn::DataType::Signed32 : armnn::DataType::Signed64;
+    desc.m_Function = ArgMinMaxFunction::Max;
+
+    IConnectableLayer* layer = m_Network->AddArgMinMaxLayer(desc, layerName.c_str());
+
+    armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+    // Register the input tensor with the layer.
+    auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+    // Register the output tensor with the layer.
+    auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+    RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
+}
+
 armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
                                                                 unsigned int outputSlot,
                                                                 tflite::ActivationFunctionType activationType)
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 6a61150..9b081a5 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -137,6 +137,7 @@
     void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
     void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
     void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
+    void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
 
     void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
     void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index f933505..98b5ada 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -31,6 +31,8 @@
             return arm_compute::DataType::QASYMM8;
         case armnn::DataType::QSymmS16:
             return arm_compute::DataType::QSYMM16;
+        case armnn::DataType::Signed64:
+            return arm_compute::DataType::S64;
         case armnn::DataType::QSymmS8:
         {
             return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8;
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 07ce14b..ff97fc7 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -623,9 +623,10 @@
     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
 
-    if (outputTensorInfo.GetDataType() != DataType::Signed32)
+    if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
+        outputTensorInfo.GetDataType() != DataType::Signed64)
     {
-        throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32.");
+        throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
     }
 
     std::vector<DataType> supportedInputTypes =
@@ -636,7 +637,8 @@
         DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16,
-        DataType::Signed32
+        DataType::Signed32,
+        DataType::Signed64
     };
 
     ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
diff --git a/src/backends/reference/test/ArgMinMaxTests.cpp b/src/backends/reference/test/ArgMinMaxTests.cpp
index 201a2c0..dce15b2 100644
--- a/src/backends/reference/test/ArgMinMaxTests.cpp
+++ b/src/backends/reference/test/ArgMinMaxTests.cpp
@@ -12,11 +12,11 @@
 BOOST_AUTO_TEST_CASE(ArgMinTest)
 {
     const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32);
-    const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Float32);
+    const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64);
 
     std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f});
-    std::vector<int32_t> outputValues(outputInfo.GetNumElements());
-    std::vector<int32_t> expectedValues({ 0, 1, 0 });
+    std::vector<int64_t> outputValues(outputInfo.GetNumElements());
+    std::vector<int64_t> expectedValues({ 0, 1, 0 });
 
     ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()),
                outputValues.data(),
@@ -35,11 +35,11 @@
 BOOST_AUTO_TEST_CASE(ArgMaxTest)
 {
     const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32);
-    const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Float32);
+    const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64);
 
     std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f });
-    std::vector<int32_t> outputValues(outputInfo.GetNumElements());
-    std::vector<int32_t> expectedValues({ 1, 0, 1 });
+    std::vector<int64_t> outputValues(outputInfo.GetNumElements());
+    std::vector<int64_t> expectedValues({ 1, 0, 1 });
 
     ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()),
                outputValues.data(),
diff --git a/src/backends/reference/workloads/ArgMinMax.cpp b/src/backends/reference/workloads/ArgMinMax.cpp
index c455c52..3bf2853 100644
--- a/src/backends/reference/workloads/ArgMinMax.cpp
+++ b/src/backends/reference/workloads/ArgMinMax.cpp
@@ -12,7 +12,8 @@
 namespace armnn
 {
 
-void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+template <typename OUT>
+void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo,
                const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis)
 {
     IgnoreUnused(outputTensorInfo);
@@ -39,9 +40,16 @@
                     tmpIndex = i;
                 }
             }
-            out[outer * innerElements + inner] = armnn::numeric_cast<int32_t>(tmpIndex);
+
+            out[outer * innerElements + inner] = armnn::numeric_cast<OUT>(tmpIndex);
         }
     }
 }
 
+template void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+               const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
+
+template void ArgMinMax(Decoder<float>& in, int64_t* out, const TensorInfo& inputTensorInfo,
+               const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
+
 } //namespace armnn
diff --git a/src/backends/reference/workloads/ArgMinMax.hpp b/src/backends/reference/workloads/ArgMinMax.hpp
index 5a9c6a8..3958ed7 100644
--- a/src/backends/reference/workloads/ArgMinMax.hpp
+++ b/src/backends/reference/workloads/ArgMinMax.hpp
@@ -13,7 +13,8 @@
 namespace armnn
 {
 
-void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+template <typename OUT>
+void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo,
                const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
 
 } //namespace armnn
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 937a320..cd9efc9 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -5,8 +5,6 @@
 
 list(APPEND armnnRefBackendWorkloads_sources
     Abs.hpp
-    ArgMinMax.cpp
-    ArgMinMax.hpp
     Activation.cpp
     Activation.hpp
     ArgMinMax.cpp
diff --git a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
index 5f1eb73..b7246d5 100644
--- a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
+++ b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
@@ -29,10 +29,18 @@
 
     const TensorInfo &outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
 
-    int32_t* output = GetOutputTensorData<int32_t>(0, m_Data);
-
-    ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
-              m_Data.m_Parameters.m_Axis);
+    if (m_Data.m_Parameters.m_Output_Type == armnn::DataType::Signed32)
+    {
+        int32_t* output = GetOutputTensorData<int32_t>(0, m_Data);
+        ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
+                  m_Data.m_Parameters.m_Axis);
+    }
+    else
+    {
+        int64_t* output = GetOutputTensorData<int64_t>(0, m_Data);
+        ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
+                  m_Data.m_Parameters.m_Axis);
+    }
 }
 
 } //namespace armnn
\ No newline at end of file