blob: f787a220906b9bf020be945a76c82ad66d7461f5 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Sadik Armagan8b9858d2020-11-09 08:26:22 +00008#include "DelegateUtils.hpp"
9
Sadik Armagan62483be2020-10-23 17:14:43 +010010#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
Sadik Armagan8b9858d2020-11-09 08:26:22 +000022 int32_t tfLiteComparisonOperatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010023{
Sadik Armagan8b9858d2020-11-09 08:26:22 +000024 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
25 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26
27 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
28 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
29 if (IsDynamicTensor(tfLiteInputTensor0))
30 {
31 TF_LITE_MAYBE_KERNEL_LOG(
32 tfLiteContext,
33 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
34 tfLiteComparisonOperatorCode, nodeIndex);
35 return kTfLiteError;
36 }
37
38 const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
39 if (IsDynamicTensor(tfLiteInputTensor1))
40 {
41 TF_LITE_MAYBE_KERNEL_LOG(
42 tfLiteContext,
43 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
44 tfLiteComparisonOperatorCode, nodeIndex);
45 return kTfLiteError;
46 }
47
48 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
49 if (IsDynamicTensor(tfLiteOutputTensor))
50 {
51 TF_LITE_MAYBE_KERNEL_LOG(
52 tfLiteContext,
53 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
54 tfLiteComparisonOperatorCode, nodeIndex);
55 return kTfLiteError;
56 }
57
58 const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
59 const armnn::TensorInfo& inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
60 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
61
62 armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
63 switch(tfLiteComparisonOperatorCode)
64 {
65 case kTfLiteBuiltinEqual:
66 comparisonOperation = armnn::ComparisonOperation::Equal;
67 break;
68 case kTfLiteBuiltinGreater:
69 comparisonOperation = armnn::ComparisonOperation::Greater;
70 break;
71 case kTfLiteBuiltinGreaterEqual:
72 comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual;
73 break;
74 case kTfLiteBuiltinLess:
75 comparisonOperation = armnn::ComparisonOperation::Less;
76 break;
77 case kTfLiteBuiltinLessEqual:
78 comparisonOperation = armnn::ComparisonOperation::LessOrEqual;
79 break;
80 case kTfLiteBuiltinNotEqual:
81 comparisonOperation = armnn::ComparisonOperation::NotEqual;
82 break;
83 default:
84 return kTfLiteError;
85 }
86
87 armnn::ComparisonDescriptor descriptor(comparisonOperation);
88 bool isSupported = false;
89
90 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
91 {
92 FORWARD_LAYER_SUPPORT_FUNC(__func__,
93 tfLiteContext,
94 IsComparisonSupported,
95 delegateData.m_Backends,
96 isSupported,
97 inputTensorInfo0,
98 inputTensorInfo1,
99 outputTensorInfo,
100 descriptor);
101 };
102
103 if (!delegateData.m_Network)
104 {
105 validateFunc(outputTensorInfo, isSupported);
106 return isSupported ? kTfLiteOk : kTfLiteError;
107 }
108
109 armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor);
110 ARMNN_ASSERT(comparisonLayer != nullptr);
111
112 armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0);
113 outputSlot.SetTensorInfo(outputTensorInfo);
114
115 auto reshapeLayer = BroadcastTensor(inputTensorInfo0,
116 inputTensorInfo1,
117 comparisonLayer,
118 tfLiteContext,
119 tfLiteNode,
120 delegateData);
121 if (!reshapeLayer)
122 {
123 return kTfLiteError;
124 }
125 return kTfLiteOk;
Sadik Armagan62483be2020-10-23 17:14:43 +0100126}
127
128} // namespace armnnDelegate