IVGCVSW-5804 TfLiteParser fails to correctly parse ArgMinMax

 * Fix for GitHub#523.
 * Updated the ParseArgMinMax function to read the correct axis data.
 * Improved validation in the ParseArgMinMax function; an illustrative sketch of the axis check follows below.
 * Added ARG_MIN support to TfLiteParser.
 * Added ArgMinMax unit tests for TfLiteParser.
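
The axis handling added here can be summarised with the following sketch. It is
illustrative only: the helper ValidateArgMinMaxAxis is hypothetical and does not
exist in the code base, but it mirrors the [-n, n) range check that ParseArgMinMax
now applies to the constant axis value before storing it in
ArgMinMaxDescriptor::m_Axis.

    #include <cstdint>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Hypothetical helper mirroring the validation in ParseArgMinMax:
    // take the first element of the constant axis buffer and reject values
    // outside [-rank, rank). The axis is returned unchanged, as it is stored
    // in ArgMinMaxDescriptor::m_Axis.
    int32_t ValidateArgMinMaxAxis(const std::vector<int32_t>& axisBuffer, int32_t inputRank)
    {
        if (axisBuffer.empty())
        {
            // Corresponds to the new "Failed to read axis" ParseException.
            throw std::invalid_argument("Operation has invalid inputs. Failed to read axis.");
        }

        const int32_t axis = axisBuffer.front();
        if (axis < -inputRank || axis >= inputRank)
        {
            // E.g. a rank 4 tensor accepts axis values in [-4, 4),
            // where -1 == 3, -2 == 2, -3 == 1 and -4 == 0.
            throw std::invalid_argument("Operation has invalid axis: " + std::to_string(axis) +
                                        ". Axis must be in range [-n, n)");
        }
        return axis;
    }

    // Example: ValidateArgMinMaxAxis({ 3 }, 4) returns 3;
    //          ValidateArgMinMaxAxis({ 4 }, 4) throws.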

Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: Ib4ce1a7c66e210c47859a130c4896aac958f2654
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8ccb270..edcf5cc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -806,6 +806,7 @@
         list(APPEND unittest_sources
              src/armnnTfLiteParser/test/Activations.cpp
              src/armnnTfLiteParser/test/Addition.cpp
+             src/armnnTfLiteParser/test/ArgMinMax.cpp
              src/armnnTfLiteParser/test/AvgPool2D.cpp
              src/armnnTfLiteParser/test/BatchToSpaceND.cpp
              src/armnnTfLiteParser/test/Concatenation.cpp
diff --git a/docs/01_01_parsers.dox b/docs/01_01_parsers.dox
index ae49303..05e1225 100644
--- a/docs/01_01_parsers.dox
+++ b/docs/01_01_parsers.dox
@@ -158,6 +158,8 @@
 The Arm NN SDK TensorFlow Lite parser currently supports the following operators:
 
 - ADD
+- ARG_MAX
+- ARG_MIN
 - AVERAGE_POOL_2D, Supported Fused Activation: RELU , RELU6 , TANH, NONE
 - BATCH_TO_SPACE
 - CONCATENATION, Supported Fused Activation: RELU , RELU6 , TANH, NONE
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index c4d2942..cb3426e 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -604,6 +604,8 @@
 {
     // register supported operators
     m_ParserFunctions[tflite::BuiltinOperator_ADD]                     = &TfLiteParserImpl::ParseAdd;
+    m_ParserFunctions[tflite::BuiltinOperator_ARG_MIN]                 = &TfLiteParserImpl::ParseArgMin;
+    m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX]                 = &TfLiteParserImpl::ParseArgMax;
     m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D]         = &TfLiteParserImpl::ParseAveragePool2D;
     m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND]       = &TfLiteParserImpl::ParseBatchToSpaceND;
     m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION]           = &TfLiteParserImpl::ParseConcatenation;
@@ -612,6 +614,7 @@
     m_ParserFunctions[tflite::BuiltinOperator_DEPTH_TO_SPACE]          = &TfLiteParserImpl::ParseDepthToSpace;
     m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D]       = &TfLiteParserImpl::ParseDepthwiseConv2D;
     m_ParserFunctions[tflite::BuiltinOperator_DEQUANTIZE]              = &TfLiteParserImpl::ParseDequantize;
+    m_ParserFunctions[tflite::BuiltinOperator_DIV]                     = &TfLiteParserImpl::ParseDiv;
     m_ParserFunctions[tflite::BuiltinOperator_ELU]                     = &TfLiteParserImpl::ParseElu;
     m_ParserFunctions[tflite::BuiltinOperator_EXP]                     = &TfLiteParserImpl::ParseExp;
     m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED]         = &TfLiteParserImpl::ParseFullyConnected;
@@ -649,8 +652,7 @@
     m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE]               = &TfLiteParserImpl::ParseTranspose;
     m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV]          = &TfLiteParserImpl::ParseTransposeConv;
     m_ParserFunctions[tflite::BuiltinOperator_UNPACK]                  = &TfLiteParserImpl::ParseUnpack;
-    m_ParserFunctions[tflite::BuiltinOperator_DIV]                     = &TfLiteParserImpl::ParseDiv;
-    m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX]                 = &TfLiteParserImpl::ParseArgMax;
+
     // register supported custom operators
     m_CustomParserFunctions["TFLite_Detection_PostProcess"]      = &TfLiteParserImpl::ParseDetectionPostProcess;
 }
@@ -2939,8 +2941,18 @@
     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
 }
 
+void TfLiteParserImpl::ParseArgMin(size_t subgraphIndex, size_t operatorIndex)
+{
+    ParseArgMinMax(subgraphIndex, operatorIndex, armnn::ArgMinMaxFunction::Min);
+}
+
 void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
 {
+    ParseArgMinMax(subgraphIndex, operatorIndex, armnn::ArgMinMaxFunction::Max);
+}
+
+void TfLiteParserImpl::ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, ArgMinMaxFunction argMinMaxFunction)
+{
     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
     CHECK_VALID_SIZE(inputs.size(), 2);
@@ -2948,22 +2960,11 @@
     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
     CHECK_VALID_SIZE(outputs.size(), 1);
 
-    auto layerName = fmt::format("ArgMax:{}:{}", subgraphIndex, operatorIndex);
-
-    armnn::TensorInfo sizeTensorInfo0 = ToTensorInfo(inputs[0]);
-    armnn::TensorInfo sizeTensorInfo1 = ToTensorInfo(inputs[1]);
-
-    // Get const axis value from model and set it to descriptor.
-    BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
-
-    ArgMinMaxDescriptor desc;
-    desc.m_Axis = axisBufferPtr->data.data()[0];
-    desc.m_Function = ArgMinMaxFunction::Max;
-
-    // Register a ArgMax layer.
-    IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerName.c_str());
-
+    armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
+    armnn::TensorInfo axisTensorInfo   = ToTensorInfo(inputs[1]);
     armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+
+    // Check if output tensor type is Signed32 or Signed64
     if (outputTensorInfo.GetDataType() != armnn::DataType::Signed32 &&
         outputTensorInfo.GetDataType() != armnn::DataType::Signed64)
     {
@@ -2972,6 +2973,41 @@
                         "Output tensor data type is not supported. (Supported types: Signed32 & Signed64) {}",
                                 CHECK_LOCATION().AsString()));
     }
+
+    // Get const axis value from model and set it to descriptor.
+    BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+    if (axisBufferPtr == nullptr)
+    {
+        throw ParseException(
+                fmt::format("Operation has invalid inputs. Failed to read axis. {}",
+                            CHECK_LOCATION().AsString()));
+    }
+
+    std::vector<int32_t> axisData(axisTensorInfo.GetNumElements());
+    ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
+    int32_t axis = axisData.front();
+
+    auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
+    if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
+    {
+        // Square brackets denote an inclusive bound, parentheses an exclusive bound.
+        // E.g. a rank 4 tensor can have an axis in the range [-4, 4)
+        // -1 == 3, -2 == 2, -3 == 1, -4 == 0
+        throw ParseException(
+                fmt::format("Operation has invalid axis: {}. Axis must be in range [-n, n) {}",
+                                    axis,
+                                    CHECK_LOCATION().AsString()));
+    }
+
+    ArgMinMaxDescriptor desc;
+    desc.m_Axis = axis;
+    desc.m_Function = argMinMaxFunction;
+
+    // Register an ArgMin/ArgMax layer.
+    auto layerName = argMinMaxFunction == ArgMinMaxFunction::Max ? "ArgMax:{}:{}" : "ArgMin:{}:{}";
+    auto layerNameFormatted = fmt::format(layerName, subgraphIndex, operatorIndex);
+    IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerNameFormatted.c_str());
+    ARMNN_ASSERT(layer != nullptr);
     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
 
     // Register input tensor to the layer.
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 07ff481..90517f5 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -99,6 +99,9 @@
 
     void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
     void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
+    void ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, armnn::ArgMinMaxFunction argMinMaxFunction);
+    void ParseArgMin(size_t subgraphIndex, size_t operatorIndex);
+    void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
     void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
     void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
     void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
@@ -107,6 +110,7 @@
     void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
     void ParseDequantize(size_t subgraphIndex, size_t operatorIndex);
     void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
+    void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
     void ParseElu(size_t subgraphIndex, size_t operatorIndex);
     void ParseExp(size_t subgraphIndex, size_t operatorIndex);
     void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
@@ -143,12 +147,10 @@
     void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex);
     void ParseSub(size_t subgraphIndex, size_t operatorIndex);
     void ParseSum(size_t subgraphIndex, size_t operatorIndex);
-    void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
     void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
     void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
     void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
     void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
-    void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
 
     void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
     void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
diff --git a/src/armnnTfLiteParser/test/ArgMinMax.cpp b/src/armnnTfLiteParser/test/ArgMinMax.cpp
new file mode 100644
index 0000000..ad99b48
--- /dev/null
+++ b/src/armnnTfLiteParser/test/ArgMinMax.cpp
@@ -0,0 +1,164 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "ParserFlatbuffersFixture.hpp"
+#include "../TfLiteParser.hpp"
+
+#include <iostream>
+#include <string>
+
+BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
+
+struct ArgMinMaxFixture : public ParserFlatbuffersFixture
+{
+    explicit ArgMinMaxFixture(const std::string& operatorCode,
+                              const std::string& inputShape,
+                              const std::string& outputShape,
+                              const std::string& axisData)
+    {
+        m_JsonString = R"(
+            {
+                "version": 3,
+                "operator_codes": [ { "builtin_code": )" + operatorCode + R"( } ],
+                "subgraphs": [ {
+                    "tensors": [
+                        {
+                            "shape": )" + inputShape + R"(,
+                            "type": "FLOAT32",
+                            "buffer": 0,
+                            "name": "inputTensor",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                        {
+                            "shape": )" + outputShape + R"( ,
+                            "type": "INT32",
+                            "buffer": 1,
+                            "name": "outputTensor",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                        {
+                            "shape": [ 1 ],
+                            "type": "INT32",
+                            "buffer": 2,
+                            "name": "axis",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                    ],
+                    "inputs": [ 0 ],
+                    "outputs": [ 1 ],
+                    "operators": [
+                        {
+                            "opcode_index": 0,
+                            "inputs": [ 0 , 2 ],
+                            "outputs": [ 1 ],
+                            "custom_options_format": "FLEXBUFFERS"
+                        }
+                    ],
+                } ],
+                "buffers" : [
+                    { },
+                    { },
+                    { "data": )" + axisData + R"(, },
+                ]
+            }
+        )";
+
+        SetupSingleInputSingleOutput("inputTensor", "outputTensor");
+    }
+};
+
+struct SimpleArgMaxFixture : public ArgMinMaxFixture
+{
+    SimpleArgMaxFixture() : ArgMinMaxFixture("ARG_MAX",
+                                             "[ 1, 1, 1, 5 ]",
+                                             "[ 1, 1, 1 ]",
+                                             "[ 3, 0, 0, 0 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseSimpleArgMax, SimpleArgMaxFixture)
+{
+    RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed32>(
+            0,
+            {{ "inputTensor",  { 6.0f, 2.0f, 8.0f, 10.0f, 9.0f } } },
+            {{ "outputTensor", { 3l } } });
+}
+
+struct ArgMaxFixture : public ArgMinMaxFixture
+{
+    ArgMaxFixture() : ArgMinMaxFixture("ARG_MAX",
+                                       "[ 3, 2, 1, 4 ]",
+                                       "[ 2, 1, 4 ]",
+                                       "[ 0, 0, 0, 0 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseArgMax, ArgMaxFixture)
+{
+    RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed32>(
+            0,
+            {{ "inputTensor", { 1.0f,   2.0f,   3.0f,   4.0f,
+                                8.0f,   7.0f,   6.0f,   6.0f,
+                                100.0f, 20.0f,  300.0f, 40.0f,
+                                500.0f, 476.0f, 450.0f, 426.0f,
+                                50.0f,  60.0f,  70.0f,  80.0f,
+                                10.0f,  200.0f, 30.0f,  400.0f } } },
+            {{ "outputTensor", { 1, 2, 1, 2,
+                                 1, 1, 1, 1 } } });
+}
+
+struct SimpleArgMinFixture : public ArgMinMaxFixture
+{
+    SimpleArgMinFixture() : ArgMinMaxFixture("ARG_MIN",
+                                             "[ 1, 1, 1, 5 ]",
+                                             "[ 1, 1, 1 ]",
+                                             "[ 3, 0, 0, 0 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseSimpleArgMin, SimpleArgMinFixture)
+{
+    RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed32>(
+            0,
+            {{ "inputTensor",  { 6.0f, 2.0f, 8.0f, 10.0f, 9.0f } } },
+            {{ "outputTensor", { 1l } } });
+}
+
+struct ArgMinFixture : public ArgMinMaxFixture
+{
+    ArgMinFixture() : ArgMinMaxFixture("ARG_MIN",
+                                       "[ 3, 2, 1, 4 ]",
+                                       "[ 2, 1, 4 ]",
+                                       "[ 0, 0, 0, 0 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseArgMin, ArgMinFixture)
+{
+    RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed32>(
+            0,
+            {{ "inputTensor", { 1.0f,   2.0f,   3.0f,   4.0f,
+                                8.0f,   7.0f,   6.0f,   6.0f,
+                                100.0f, 20.0f,  300.0f, 40.0f,
+                                500.0f, 476.0f, 450.0f, 426.0f,
+                                50.0f,  60.0f,  70.0f,  80.0f,
+                                10.0f,  200.0f, 30.0f,  400.0f } } },
+            {{ "outputTensor", { 0, 0, 0, 0,
+                                 0, 0, 0, 0 } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
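
For completeness, here is roughly how the new operator support is exercised through
the public parser API. The snippet is only a sketch: "argminmax_model.tflite" is a
placeholder path for any TensorFlow Lite model that contains ARG_MIN or ARG_MAX
operators.

    #include <armnn/INetwork.hpp>
    #include <armnnTfLiteParser/ITfLiteParser.hpp>

    int main()
    {
        // With this change, ARG_MIN and ARG_MAX operators are dispatched to
        // ParseArgMin/ParseArgMax and lowered to ArgMinMax layers.
        auto parser = armnnTfLiteParser::ITfLiteParser::Create();
        armnn::INetworkPtr network =
            parser->CreateNetworkFromBinaryFile("argminmax_model.tflite");
        return network ? 0 : 1;
    }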