//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnOpaqueDelegate
{

TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
                                     TfLiteOpaqueContext* tfLiteContext,
                                     TfLiteOpaqueNode* tfLiteNode,
                                     int nodeIndex,
                                     int32_t tfLiteComparisonOperatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    // Gather the input tensor indices from the node.
    int numInputs = 0;
    const int* inputTensors;
    if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
                nodeIndex);
        return kTfLiteError;
    }

    // Use the input indices to get the input tensors.
    const TfLiteOpaqueTensor* tfLiteInputTensor0 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
    if (!IsValid(tfLiteContext, tfLiteInputTensor0, tfLiteComparisonOperatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteOpaqueTensor* tfLiteInputTensor1 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
    if (!IsValid(tfLiteContext, tfLiteInputTensor1, tfLiteComparisonOperatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    // Gather the output tensor indices from the node and use them to get the output tensor.
    int numOutputs = 0;
    const int* outputTensors;
    if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
                nodeIndex);
        return kTfLiteError;
    }

    const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteComparisonOperatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

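    // Build Arm NN tensor infos from the TfLite tensors; the 'true' argument on the
    // third call flags the tensor as the operator's output.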
    armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor0);
    armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor1);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);

    // Check if we need to expand the dims of the input tensor infos.
    // This is required for a few of the backends.
    if (inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
    {
        ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
    }

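    // Map the TfLite builtin operator code onto the corresponding Arm NN comparison operation.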
    armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
    switch (tfLiteComparisonOperatorCode)
    {
        case kTfLiteBuiltinEqual:
            comparisonOperation = armnn::ComparisonOperation::Equal;
            break;
        case kTfLiteBuiltinGreater:
            comparisonOperation = armnn::ComparisonOperation::Greater;
            break;
        case kTfLiteBuiltinGreaterEqual:
            comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual;
            break;
        case kTfLiteBuiltinLess:
            comparisonOperation = armnn::ComparisonOperation::Less;
            break;
        case kTfLiteBuiltinLessEqual:
            comparisonOperation = armnn::ComparisonOperation::LessOrEqual;
            break;
        case kTfLiteBuiltinNotEqual:
            comparisonOperation = armnn::ComparisonOperation::NotEqual;
            break;
        default:
            return kTfLiteError;
    }

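    // The descriptor carries the chosen comparison operation; it is used both for the
    // backend support check below and, when a network is being built, for the added layer.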
    armnn::ComparisonDescriptor descriptor(comparisonOperation);
    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("COMPARISON",
                                          tfLiteContext,
                                          IsComparisonSupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          setBackend,
                                          inputTensorInfo0,
                                          inputTensorInfo1,
                                          outputTensorInfo,
                                          descriptor);
    };

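    // If there is no network to build, this call is only checking whether the operator
    // is supported, so run the validation function and report the result.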
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor);
    ARMNN_ASSERT(comparisonLayer != nullptr);
    comparisonLayer->SetBackendId(setBackend);

    armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Try to connect the constant inputs, if there are any.
    if (ProcessInputs(comparisonLayer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
    {
        return kTfLiteError;
    }

    return Connect(comparisonLayer, tfLiteContext, tfLiteNode, delegateData);
}

} // namespace armnnOpaqueDelegate