blob: 090d18ef65e33da0d5c7744a50db08f197529eb5 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
Sadik Armagandc032fc2021-01-19 17:24:21 +000011#include <tensorflow/lite/kernels/internal/tensor_ctypes.h>
Sadik Armagan62483be2020-10-23 17:14:43 +010012#include <tensorflow/lite/minimal_logging.h>
13
14namespace armnnDelegate
15{
16
17TfLiteStatus VisitArgMinMaxOperator(DelegateData& delegateData,
18 TfLiteContext* tfLiteContext,
19 TfLiteNode* tfLiteNode,
20 int nodeIndex,
21 int32_t argMinMaxOperatorCode)
22{
Sadik Armagandc032fc2021-01-19 17:24:21 +000023 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000025
Sadik Armagandc032fc2021-01-19 17:24:21 +000026 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
28 if (!IsValid(tfLiteContext, tfLiteInputTensor, argMinMaxOperatorCode, nodeIndex))
29 {
30 return kTfLiteError;
31 }
32
33 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
34 if (!IsValid(tfLiteContext, tfLiteOutputTensor, argMinMaxOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
40 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
41
42 // Get const axis value from model and set it to descriptor.
43 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
44 if (!IsValid(tfLiteContext, tfLiteAxisTensor, argMinMaxOperatorCode, nodeIndex))
45 {
46 return kTfLiteError;
47 }
48
49 armnn::ArgMinMaxDescriptor desc;
50 // Get the axis value from the input tensor
51 switch (tfLiteAxisTensor.type)
52 {
53 case kTfLiteInt32:
54 case kTfLiteInt64:
55 desc.m_Axis = tflite::GetTensorData<int>(&tfLiteAxisTensor)[0];
56 break;
57 default:
58 TF_LITE_MAYBE_KERNEL_LOG(
59 tfLiteContext,
60 "TfLiteArmnnDelegate: Axis value data type is not supported in operator #%d node #%d: ",
61 argMinMaxOperatorCode, nodeIndex);
62 return kTfLiteError;
63 }
64
65 // If output_type is int32 then set Signed32 else Signed64. Default type is Signed64.
66 if (argMinMaxOperatorCode == kTfLiteBuiltinArgMax)
67 {
68 desc.m_Function = armnn::ArgMinMaxFunction::Max;
69 auto* argMaxParameters = reinterpret_cast<TfLiteArgMaxParams*>(tfLiteNode->builtin_data);
70 switch (argMaxParameters->output_type)
71 {
72 case kTfLiteInt32:
73 desc.m_Output_Type = armnn::DataType::Signed32;
74 break;
75 case kTfLiteInt64:
76 desc.m_Output_Type = armnn::DataType::Signed64;
77 break;
78 default:
79 TF_LITE_MAYBE_KERNEL_LOG(
80 tfLiteContext,
81 "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ",
82 argMinMaxOperatorCode, nodeIndex);
83 return kTfLiteError;
84 }
85 }
86 else
87 {
88 desc.m_Function = armnn::ArgMinMaxFunction::Min;
89 auto* argMinParameters = reinterpret_cast<TfLiteArgMinParams*>(tfLiteNode->builtin_data);
90 switch (argMinParameters->output_type)
91 {
92 case kTfLiteInt32:
93 desc.m_Output_Type = armnn::DataType::Signed32;
94 break;
95 case kTfLiteInt64:
96 desc.m_Output_Type = armnn::DataType::Signed64;
97 break;
98 default:
99 TF_LITE_MAYBE_KERNEL_LOG(
100 tfLiteContext,
101 "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ",
102 argMinMaxOperatorCode, nodeIndex);
103 return kTfLiteError;
104 }
105 }
106
107 bool isSupported = false;
108 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
109 {
110 FORWARD_LAYER_SUPPORT_FUNC(__func__,
111 tfLiteContext,
112 IsArgMinMaxSupported,
113 delegateData.m_Backends,
114 isSupported,
115 inputTensorInfo,
116 outInfo,
117 desc);
118 };
119
120 if (!delegateData.m_Network)
121 {
122 validateFunc(outputTensorInfo, isSupported);
123 return isSupported ? kTfLiteOk : kTfLiteError;
124 }
125
126 // Add an ArgMinMax layer
127 armnn::IConnectableLayer* layer = delegateData.m_Network->AddArgMinMaxLayer(desc);
128 ARMNN_ASSERT(layer != nullptr);
129
130 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
131 outputSlot.SetTensorInfo(outputTensorInfo);
132
133 // Connect
134 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100135}
136
137} // namespace armnnDelegate