blob: 07b55c3e32374aa97f140347a5e831b7970b9eb6 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
11#include <tensorflow/lite/minimal_logging.h>
12
13namespace armnnDelegate
14{
15
16TfLiteStatus VisitLogicalBinaryOperator(DelegateData& delegateData,
17 TfLiteContext* tfLiteContext,
18 TfLiteNode* tfLiteNode,
19 int nodeIndex,
20 int32_t logicalOperatorCode,
21 armnn::LogicalBinaryOperation binaryOperation)
22{
23 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25
26 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
28 if (!IsValid(tfLiteContext, tfLiteInputTensor0, logicalOperatorCode, nodeIndex))
29 {
30 return kTfLiteError;
31 }
32
33 const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
34 if (!IsValid(tfLiteContext, tfLiteInputTensor1, logicalOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
40 if (!IsValid(tfLiteContext, tfLiteOutputTensor, logicalOperatorCode, nodeIndex))
41 {
42 return kTfLiteError;
43 }
44
45 armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
46 armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
47 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
48
49 // Setup descriptor and assign operation
50 armnn::LogicalBinaryDescriptor desc;
51 desc.m_Operation = binaryOperation;
52
53 // Check if supported
54 bool isSupported = false;
55 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
56 {
57 FORWARD_LAYER_SUPPORT_FUNC(__func__,
58 tfLiteContext,
59 IsLogicalBinarySupported,
60 delegateData.m_Backends,
61 isSupported,
62 inputTensorInfo0,
63 inputTensorInfo1,
64 outputTensorInfo,
65 desc);
66 };
67
68 if (!delegateData.m_Network)
69 {
70 validateFunc(outputTensorInfo, isSupported);
71 return isSupported ? kTfLiteOk : kTfLiteError;
72 }
73
74 armnn::IConnectableLayer* logicalBinaryLayer = delegateData.m_Network->AddLogicalBinaryLayer(desc);
75 ARMNN_ASSERT(logicalBinaryLayer != nullptr);
76
77 armnn::IOutputSlot& outputSlot = logicalBinaryLayer->GetOutputSlot(0);
78 outputSlot.SetTensorInfo(outputTensorInfo);
79
80 if(tflite::IsConstantTensor(&tfLiteInputTensor0))
81 {
82 auto status = ConnectConstant(logicalBinaryLayer,
83 inputTensorInfo0,
84 tfLiteContext,
85 tfLiteInputTensor0,
86 delegateData,
87 tfLiteNode->inputs->data[0]);
88 if (status == kTfLiteError)
89 {
90 return status;
91 }
92 }
93
94 if(tflite::IsConstantTensor(&tfLiteInputTensor1))
95 {
96 auto status = ConnectConstant(logicalBinaryLayer,
97 inputTensorInfo1,
98 tfLiteContext,
99 tfLiteInputTensor1,
100 delegateData,
101 tfLiteNode->inputs->data[1]);
102 if (status == kTfLiteError)
103 {
104 return status;
105 }
106 }
107
108 // LogicalBinary operators support broadcasting
109 auto reshapeLayer = BroadcastTensor(inputTensorInfo0,
110 inputTensorInfo1,
111 logicalBinaryLayer,
112 tfLiteContext,
113 tfLiteNode,
114 delegateData);
115 if (!reshapeLayer)
116 {
117 return kTfLiteError;
118 }
119 return kTfLiteOk;
120}
121
122} // namespace armnnDelegate