//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>

namespace armnnOpaqueDelegate
{

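// Builds a human-readable layer name such as "LOGICAL_BINARY LOGICAL_AND"; it is used
// when reporting the result of the backend support check below.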
std::string GetLayerName(armnn::LogicalBinaryOperation logicalBinaryOperation)
{
    std::string layerName = "LOGICAL_BINARY";
    switch (logicalBinaryOperation)
    {
        case armnn::LogicalBinaryOperation::LogicalAnd:
            layerName += " LOGICAL_AND";
            break;
        case armnn::LogicalBinaryOperation::LogicalOr:
            layerName += " LOGICAL_OR";
            break;
        default:
            layerName += " UNKNOWN";
    }
    return layerName;
}

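// Checks backend support for a TfLite LOGICAL_AND / LOGICAL_OR node and, when the delegate
// is building a network, adds the corresponding ArmNN LogicalBinary layer and connects it.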
TfLiteStatus VisitLogicalBinaryOperator(DelegateData& delegateData,
                                        TfLiteOpaqueContext* tfLiteContext,
                                        TfLiteOpaqueNode* tfLiteNode,
                                        int nodeIndex,
                                        int32_t logicalOperatorCode,
                                        armnn::LogicalBinaryOperation binaryOperation)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    // Gather the input tensor indices.
    int numInputs = 0;
    const int* inputTensors;
    if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
                nodeIndex);
        return kTfLiteError;
    }

    // Use the input indices to get the input tensors.
    const TfLiteOpaqueTensor* tfLiteInputTensor0 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
    if (!IsValid(tfLiteContext, tfLiteInputTensor0, logicalOperatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteOpaqueTensor* tfLiteInputTensor1 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
    if (!IsValid(tfLiteContext, tfLiteInputTensor1, logicalOperatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    // Gather the output tensor indices.
    int numOutputs = 0;
    const int* outputTensors;
    if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
                nodeIndex);
        return kTfLiteError;
    }

    // Use the output index to get the output tensor.
    const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, logicalOperatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor0);
    armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor1);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);

    // Check if we need to expand the dims of any input tensor infos.
    // This is required for a few of the backends.
    if (inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
    {
        ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
    }

    // Set up the descriptor and assign the operation.
    armnn::LogicalBinaryDescriptor desc;
    desc.m_Operation = binaryOperation;

    // Check if the layer is supported by the configured backends.
    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported, std::string layerName)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC(layerName.c_str(),
                                          tfLiteContext,
                                          IsLogicalBinarySupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          setBackend,
                                          inputTensorInfo0,
                                          inputTensorInfo1,
                                          outputTensorInfo,
                                          desc);
    };

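    // If no network is being built, the delegate only wants to know whether this operator
    // is supported, so run the support check and return.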
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported, GetLayerName(binaryOperation));
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

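    // Add the LogicalBinary layer to the ArmNN network and pin it to the backend that
    // reported support above.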
    armnn::IConnectableLayer* logicalBinaryLayer = delegateData.m_Network->AddLogicalBinaryLayer(desc);
    ARMNN_ASSERT(logicalBinaryLayer != nullptr);
    logicalBinaryLayer->SetBackendId(setBackend);

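    // Propagate the output TensorInfo to the layer's single output slot.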
    armnn::IOutputSlot& outputSlot = logicalBinaryLayer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

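    // Process the node's inputs (for example, constant input tensors are added to the
    // network here) before the layer is wired up.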
    auto inputsTensorsProcess = ProcessInputs(logicalBinaryLayer,
                                              delegateData,
                                              tfLiteContext,
                                              tfLiteNode);
    if (inputsTensorsProcess == kTfLiteError)
    {
        return inputsTensorsProcess;
    }

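    // Connect the layer's input and output slots to the tensors the delegate is tracking
    // for this subgraph.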
    return Connect(logicalBinaryLayer, tfLiteContext, tfLiteNode, delegateData);
}

} // namespace armnnOpaqueDelegate