blob: ead08d1724aa2803aa0dd8a60898a84b2d5a64d2 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
9
Finn Williams6f9f9902020-11-13 13:23:15 +000010#include <armnn/utility/IgnoreUnused.hpp>
Sadik Armagan8b9858d2020-11-09 08:26:22 +000011
Sadik Armagan62483be2020-10-23 17:14:43 +010012#include <tensorflow/lite/builtin_ops.h>
13#include <tensorflow/lite/c/builtin_op_data.h>
14#include <tensorflow/lite/c/common.h>
15#include <tensorflow/lite/minimal_logging.h>
16
17namespace armnnDelegate
18{
19
20TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
21 TfLiteContext* tfLiteContext,
22 TfLiteNode* tfLiteNode,
23 int nodeIndex,
Sadik Armagan8b9858d2020-11-09 08:26:22 +000024 int32_t tfLiteComparisonOperatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010025{
Sadik Armagan8b9858d2020-11-09 08:26:22 +000026 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
27 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28
29 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
30 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
31 if (IsDynamicTensor(tfLiteInputTensor0))
32 {
33 TF_LITE_MAYBE_KERNEL_LOG(
34 tfLiteContext,
35 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
36 tfLiteComparisonOperatorCode, nodeIndex);
Finn Williams6f9f9902020-11-13 13:23:15 +000037
Sadik Armagan8b9858d2020-11-09 08:26:22 +000038 return kTfLiteError;
39 }
40
41 const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
42 if (IsDynamicTensor(tfLiteInputTensor1))
43 {
44 TF_LITE_MAYBE_KERNEL_LOG(
45 tfLiteContext,
46 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
47 tfLiteComparisonOperatorCode, nodeIndex);
48 return kTfLiteError;
49 }
50
51 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
52 if (IsDynamicTensor(tfLiteOutputTensor))
53 {
54 TF_LITE_MAYBE_KERNEL_LOG(
55 tfLiteContext,
56 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
57 tfLiteComparisonOperatorCode, nodeIndex);
58 return kTfLiteError;
59 }
60
Ryan OSheaa544f0f2023-01-25 18:10:20 +000061 armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
62 armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
Sadik Armagan90a119b2022-08-05 16:12:49 +010063 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan8b9858d2020-11-09 08:26:22 +000064
Ryan OSheaa544f0f2023-01-25 18:10:20 +000065 // Check if we need to expand the dims of any of the input tensor infos.
66 // This is required for a few of the backends.
67 if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
68 {
69 ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
70 }
71
Sadik Armagan8b9858d2020-11-09 08:26:22 +000072 armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
73 switch(tfLiteComparisonOperatorCode)
74 {
75 case kTfLiteBuiltinEqual:
76 comparisonOperation = armnn::ComparisonOperation::Equal;
77 break;
78 case kTfLiteBuiltinGreater:
79 comparisonOperation = armnn::ComparisonOperation::Greater;
80 break;
81 case kTfLiteBuiltinGreaterEqual:
82 comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual;
83 break;
84 case kTfLiteBuiltinLess:
85 comparisonOperation = armnn::ComparisonOperation::Less;
86 break;
87 case kTfLiteBuiltinLessEqual:
88 comparisonOperation = armnn::ComparisonOperation::LessOrEqual;
89 break;
90 case kTfLiteBuiltinNotEqual:
91 comparisonOperation = armnn::ComparisonOperation::NotEqual;
92 break;
93 default:
94 return kTfLiteError;
95 }
96
97 armnn::ComparisonDescriptor descriptor(comparisonOperation);
98 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010099 armnn::BackendId setBackend;
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000100 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
101 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000102 FORWARD_LAYER_SUPPORT_FUNC("COMPARISON",
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000103 tfLiteContext,
104 IsComparisonSupported,
105 delegateData.m_Backends,
106 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100107 setBackend,
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000108 inputTensorInfo0,
109 inputTensorInfo1,
110 outputTensorInfo,
111 descriptor);
112 };
113
114 if (!delegateData.m_Network)
115 {
116 validateFunc(outputTensorInfo, isSupported);
117 return isSupported ? kTfLiteOk : kTfLiteError;
118 }
119
Mike Kelly07169c82023-08-02 13:23:09 +0100120 auto layerName = GetLayerName(descriptor.m_Operation, nodeIndex);
121 armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor,
122 layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +0100123 comparisonLayer->SetBackendId(setBackend);
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000124 ARMNN_ASSERT(comparisonLayer != nullptr);
125
126 armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0);
127 outputSlot.SetTensorInfo(outputTensorInfo);
128
Ryan OShea4c231de2023-01-17 15:19:20 +0000129 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100130 if (ProcessInputs(comparisonLayer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +0000131 {
132 return kTfLiteError;
133 }
134
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000135 return Connect(comparisonLayer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100136}
137
138} // namespace armnnDelegate