blob: 6d7700d191a8baadf70e8cfd5dbde3c096f6af89 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan62483be2020-10-23 17:14:43 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Teresa Charlinad1b3d72023-03-14 12:10:28 +00008#include <DelegateUtils.hpp>
Finn Williams6f9f9902020-11-13 13:23:15 +00009#include <armnn/utility/IgnoreUnused.hpp>
Sadik Armagan8b9858d2020-11-09 08:26:22 +000010
Sadik Armagan62483be2020-10-23 17:14:43 +010011#include <tensorflow/lite/builtin_ops.h>
12#include <tensorflow/lite/c/builtin_op_data.h>
13#include <tensorflow/lite/c/common.h>
14#include <tensorflow/lite/minimal_logging.h>
15
16namespace armnnDelegate
17{
18
19TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
Sadik Armagan8b9858d2020-11-09 08:26:22 +000023 int32_t tfLiteComparisonOperatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010024{
Sadik Armagan8b9858d2020-11-09 08:26:22 +000025 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27
28 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
29 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
30 if (IsDynamicTensor(tfLiteInputTensor0))
31 {
32 TF_LITE_MAYBE_KERNEL_LOG(
33 tfLiteContext,
34 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
35 tfLiteComparisonOperatorCode, nodeIndex);
Finn Williams6f9f9902020-11-13 13:23:15 +000036
Sadik Armagan8b9858d2020-11-09 08:26:22 +000037 return kTfLiteError;
38 }
39
40 const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
41 if (IsDynamicTensor(tfLiteInputTensor1))
42 {
43 TF_LITE_MAYBE_KERNEL_LOG(
44 tfLiteContext,
45 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
46 tfLiteComparisonOperatorCode, nodeIndex);
47 return kTfLiteError;
48 }
49
50 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
51 if (IsDynamicTensor(tfLiteOutputTensor))
52 {
53 TF_LITE_MAYBE_KERNEL_LOG(
54 tfLiteContext,
55 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
56 tfLiteComparisonOperatorCode, nodeIndex);
57 return kTfLiteError;
58 }
59
Ryan OSheaa544f0f2023-01-25 18:10:20 +000060 armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
61 armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
Sadik Armagan90a119b2022-08-05 16:12:49 +010062 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan8b9858d2020-11-09 08:26:22 +000063
Ryan OSheaa544f0f2023-01-25 18:10:20 +000064 // Check if we need to expand the dims of any of the input tensor infos.
65 // This is required for a few of the backends.
66 if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
67 {
68 ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
69 }
70
Sadik Armagan8b9858d2020-11-09 08:26:22 +000071 armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
72 switch(tfLiteComparisonOperatorCode)
73 {
74 case kTfLiteBuiltinEqual:
75 comparisonOperation = armnn::ComparisonOperation::Equal;
76 break;
77 case kTfLiteBuiltinGreater:
78 comparisonOperation = armnn::ComparisonOperation::Greater;
79 break;
80 case kTfLiteBuiltinGreaterEqual:
81 comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual;
82 break;
83 case kTfLiteBuiltinLess:
84 comparisonOperation = armnn::ComparisonOperation::Less;
85 break;
86 case kTfLiteBuiltinLessEqual:
87 comparisonOperation = armnn::ComparisonOperation::LessOrEqual;
88 break;
89 case kTfLiteBuiltinNotEqual:
90 comparisonOperation = armnn::ComparisonOperation::NotEqual;
91 break;
92 default:
93 return kTfLiteError;
94 }
95
96 armnn::ComparisonDescriptor descriptor(comparisonOperation);
97 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010098 armnn::BackendId setBackend;
Sadik Armagan8b9858d2020-11-09 08:26:22 +000099 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
100 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000101 FORWARD_LAYER_SUPPORT_FUNC("COMPARISON",
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000102 tfLiteContext,
103 IsComparisonSupported,
104 delegateData.m_Backends,
105 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100106 setBackend,
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000107 inputTensorInfo0,
108 inputTensorInfo1,
109 outputTensorInfo,
110 descriptor);
111 };
112
113 if (!delegateData.m_Network)
114 {
115 validateFunc(outputTensorInfo, isSupported);
116 return isSupported ? kTfLiteOk : kTfLiteError;
117 }
118
119 armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100120 comparisonLayer->SetBackendId(setBackend);
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000121 ARMNN_ASSERT(comparisonLayer != nullptr);
122
123 armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0);
124 outputSlot.SetTensorInfo(outputTensorInfo);
125
Ryan OShea4c231de2023-01-17 15:19:20 +0000126 // try to connect the Constant Inputs if there are any
127 if(ProcessInputs(comparisonLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
128 {
129 return kTfLiteError;
130 }
131
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000132 return Connect(comparisonLayer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100133}
134
135} // namespace armnnDelegate