blob: 026a43adaf320e8a21d434118ee697bc68898d4a [file] [log] [blame]
Francis Murtaghc4fb0dd2023-03-16 17:01:56 +00001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
Matthew Sloyan2b04ec32023-04-26 11:42:46 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
Matthew Sloyan2b04ec32023-04-26 11:42:46 +010010namespace armnnOpaqueDelegate
11{
12
Mike Kellya2806502023-08-03 10:42:11 +010013std::string GetOperationName(armnn::ComparisonOperation comparisonOperation)
Teresa Charlinf69ae562023-04-27 14:42:23 +010014{
15 std::string layerName = "COMPARISON";
16 switch (comparisonOperation)
17 {
18 case armnn::ComparisonOperation::NotEqual:
19 layerName += " NOT_EQUAL";
20 break;
21 case armnn::ComparisonOperation::Equal:
22 layerName += " EQUAL";
23 break;
24 case armnn::ComparisonOperation::Greater:
25 layerName += " GREATER";
26 break;
27 case armnn::ComparisonOperation::GreaterOrEqual:
28 layerName += " GREATER_OR_EQUAL";
29 break;
30 case armnn::ComparisonOperation::Less:
31 layerName += " LESS";
32 break;
33 case armnn::ComparisonOperation::LessOrEqual:
34 layerName += " LESS_OR_EQUAL";
35 break;
36 default:
37 layerName += " UNKNOWN";
38 }
39 return layerName;
40}
41
Matthew Sloyan2b04ec32023-04-26 11:42:46 +010042TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
43 TfLiteOpaqueContext* tfLiteContext,
44 TfLiteOpaqueNode* tfLiteNode,
45 int nodeIndex,
Teresa Charlinf69ae562023-04-27 14:42:23 +010046 int32_t tfLiteComparisonOperatorCode,
47 armnn::ComparisonOperation comparisonOperation)
Matthew Sloyan2b04ec32023-04-26 11:42:46 +010048{
49 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
50 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
51
52 // Gather input indices and use to get input tensor.
53 int numInputs = 0;
54 const int* inputTensors;
55 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
56 {
57 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
58 tfLiteContext,
59 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
60 nodeIndex);
61 return kTfLiteError;
62 }
63
64 // Use input indices to get input tensors.
65 const TfLiteOpaqueTensor* tfLiteInputTensor0 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
66 if (!IsValid(tfLiteContext, tfLiteInputTensor0, tfLiteComparisonOperatorCode, nodeIndex))
67 {
68 return kTfLiteError;
69 }
70
71 const TfLiteOpaqueTensor* tfLiteInputTensor1 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
72 if (!IsValid(tfLiteContext, tfLiteInputTensor1, tfLiteComparisonOperatorCode, nodeIndex))
73 {
74 return kTfLiteError;
75 }
76
77 // Gather output indices and use to get output tensors.
78 int numOutputs = 0;
79 const int* outputTensors;
80 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
81 {
82 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
83 tfLiteContext,
84 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
85 nodeIndex);
86 return kTfLiteError;
87 }
88
Teresa Charlinf69ae562023-04-27 14:42:23 +010089 // Use output indices to get output tensor.
Matthew Sloyan2b04ec32023-04-26 11:42:46 +010090 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
91 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteComparisonOperatorCode, nodeIndex))
92 {
93 return kTfLiteError;
94 }
95
96 armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor0);
97 armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor1);
98 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
99
100 // Check if we need to expand the dims of the input tensor infos.
101 // This is required for a few of the backends.
102 if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
103 {
104 ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
105 }
106
Matthew Sloyan2b04ec32023-04-26 11:42:46 +0100107 armnn::ComparisonDescriptor descriptor(comparisonOperation);
108 bool isSupported = false;
109 armnn::BackendId setBackend;
Teresa Charlinf69ae562023-04-27 14:42:23 +0100110 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported, std::string layerName)
Matthew Sloyan2b04ec32023-04-26 11:42:46 +0100111 {
Teresa Charlinf69ae562023-04-27 14:42:23 +0100112 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC(layerName.c_str(),
Matthew Sloyan2b04ec32023-04-26 11:42:46 +0100113 tfLiteContext,
114 IsComparisonSupported,
115 delegateData.m_Backends,
116 isSupported,
117 setBackend,
118 inputTensorInfo0,
119 inputTensorInfo1,
120 outputTensorInfo,
121 descriptor);
122 };
123
124 if (!delegateData.m_Network)
125 {
Mike Kellya2806502023-08-03 10:42:11 +0100126 validateFunc(outputTensorInfo, isSupported, GetOperationName(comparisonOperation));
Matthew Sloyan2b04ec32023-04-26 11:42:46 +0100127 return isSupported ? kTfLiteOk : kTfLiteError;
128 }
129
Mike Kellya2806502023-08-03 10:42:11 +0100130 auto layerName = GetName(descriptor.m_Operation, nodeIndex);
131 armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor,
132 layerName.c_str());
Matthew Sloyan2b04ec32023-04-26 11:42:46 +0100133 comparisonLayer->SetBackendId(setBackend);
134 ARMNN_ASSERT(comparisonLayer != nullptr);
135
136 armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0);
137 outputSlot.SetTensorInfo(outputTensorInfo);
138
139 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100140 if (ProcessInputs(comparisonLayer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Matthew Sloyan2b04ec32023-04-26 11:42:46 +0100141 {
142 return kTfLiteError;
143 }
144
145 return Connect(comparisonLayer, tfLiteContext, tfLiteNode, delegateData);
146}
147
Teresa Charlinf69ae562023-04-27 14:42:23 +0100148} // namespace armnnOpaqueDelegate