blob: ce12e9f7c34622522c7baaeef159a64df881c668 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Sadik Armagan8b9858d2020-11-09 08:26:22 +00008#include "DelegateUtils.hpp"
Finn Williams6f9f9902020-11-13 13:23:15 +00009#include <armnn/utility/IgnoreUnused.hpp>
Sadik Armagan8b9858d2020-11-09 08:26:22 +000010
Sadik Armagan62483be2020-10-23 17:14:43 +010011#include <tensorflow/lite/builtin_ops.h>
12#include <tensorflow/lite/c/builtin_op_data.h>
13#include <tensorflow/lite/c/common.h>
14#include <tensorflow/lite/minimal_logging.h>
15
16namespace armnnDelegate
17{
18
19TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
Sadik Armagan8b9858d2020-11-09 08:26:22 +000023 int32_t tfLiteComparisonOperatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010024{
Sadik Armagan8b9858d2020-11-09 08:26:22 +000025 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27
28 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
29 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
30 if (IsDynamicTensor(tfLiteInputTensor0))
31 {
32 TF_LITE_MAYBE_KERNEL_LOG(
33 tfLiteContext,
34 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
35 tfLiteComparisonOperatorCode, nodeIndex);
Finn Williams6f9f9902020-11-13 13:23:15 +000036
Sadik Armagan8b9858d2020-11-09 08:26:22 +000037 return kTfLiteError;
38 }
39
40 const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
41 if (IsDynamicTensor(tfLiteInputTensor1))
42 {
43 TF_LITE_MAYBE_KERNEL_LOG(
44 tfLiteContext,
45 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
46 tfLiteComparisonOperatorCode, nodeIndex);
47 return kTfLiteError;
48 }
49
50 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
51 if (IsDynamicTensor(tfLiteOutputTensor))
52 {
53 TF_LITE_MAYBE_KERNEL_LOG(
54 tfLiteContext,
55 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
56 tfLiteComparisonOperatorCode, nodeIndex);
57 return kTfLiteError;
58 }
59
60 const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
61 const armnn::TensorInfo& inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
62 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
63
64 armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
65 switch(tfLiteComparisonOperatorCode)
66 {
67 case kTfLiteBuiltinEqual:
68 comparisonOperation = armnn::ComparisonOperation::Equal;
69 break;
70 case kTfLiteBuiltinGreater:
71 comparisonOperation = armnn::ComparisonOperation::Greater;
72 break;
73 case kTfLiteBuiltinGreaterEqual:
74 comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual;
75 break;
76 case kTfLiteBuiltinLess:
77 comparisonOperation = armnn::ComparisonOperation::Less;
78 break;
79 case kTfLiteBuiltinLessEqual:
80 comparisonOperation = armnn::ComparisonOperation::LessOrEqual;
81 break;
82 case kTfLiteBuiltinNotEqual:
83 comparisonOperation = armnn::ComparisonOperation::NotEqual;
84 break;
85 default:
86 return kTfLiteError;
87 }
88
89 armnn::ComparisonDescriptor descriptor(comparisonOperation);
90 bool isSupported = false;
91
92 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
93 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000094 FORWARD_LAYER_SUPPORT_FUNC("COMPARISON",
Sadik Armagan8b9858d2020-11-09 08:26:22 +000095 tfLiteContext,
96 IsComparisonSupported,
97 delegateData.m_Backends,
98 isSupported,
99 inputTensorInfo0,
100 inputTensorInfo1,
101 outputTensorInfo,
102 descriptor);
103 };
104
105 if (!delegateData.m_Network)
106 {
107 validateFunc(outputTensorInfo, isSupported);
108 return isSupported ? kTfLiteOk : kTfLiteError;
109 }
110
111 armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor);
112 ARMNN_ASSERT(comparisonLayer != nullptr);
113
114 armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0);
115 outputSlot.SetTensorInfo(outputTensorInfo);
116
117 auto reshapeLayer = BroadcastTensor(inputTensorInfo0,
118 inputTensorInfo1,
119 comparisonLayer,
120 tfLiteContext,
121 tfLiteNode,
122 delegateData);
123 if (!reshapeLayer)
124 {
125 return kTfLiteError;
126 }
127 return kTfLiteOk;
Sadik Armagan62483be2020-10-23 17:14:43 +0100128}
129
130} // namespace armnnDelegate