blob: 2e9091b9d735ddf3c206cc2526cba4730a856af1 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
11#include <tensorflow/lite/minimal_logging.h>
12
13namespace armnnDelegate
14{
15
16TfLiteStatus VisitLogicalBinaryOperator(DelegateData& delegateData,
17 TfLiteContext* tfLiteContext,
18 TfLiteNode* tfLiteNode,
19 int nodeIndex,
20 int32_t logicalOperatorCode,
21 armnn::LogicalBinaryOperation binaryOperation)
22{
23 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25
26 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
28 if (!IsValid(tfLiteContext, tfLiteInputTensor0, logicalOperatorCode, nodeIndex))
29 {
30 return kTfLiteError;
31 }
32
33 const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
34 if (!IsValid(tfLiteContext, tfLiteInputTensor1, logicalOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
40 if (!IsValid(tfLiteContext, tfLiteOutputTensor, logicalOperatorCode, nodeIndex))
41 {
42 return kTfLiteError;
43 }
44
45 armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
46 armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
47 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
48
49 // Setup descriptor and assign operation
50 armnn::LogicalBinaryDescriptor desc;
51 desc.m_Operation = binaryOperation;
52
53 // Check if supported
54 bool isSupported = false;
55 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
56 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000057 FORWARD_LAYER_SUPPORT_FUNC("LOGICAL_BINARY",
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000058 tfLiteContext,
59 IsLogicalBinarySupported,
60 delegateData.m_Backends,
61 isSupported,
62 inputTensorInfo0,
63 inputTensorInfo1,
64 outputTensorInfo,
65 desc);
66 };
67
68 if (!delegateData.m_Network)
69 {
70 validateFunc(outputTensorInfo, isSupported);
71 return isSupported ? kTfLiteOk : kTfLiteError;
72 }
73
74 armnn::IConnectableLayer* logicalBinaryLayer = delegateData.m_Network->AddLogicalBinaryLayer(desc);
75 ARMNN_ASSERT(logicalBinaryLayer != nullptr);
76
77 armnn::IOutputSlot& outputSlot = logicalBinaryLayer->GetOutputSlot(0);
78 outputSlot.SetTensorInfo(outputTensorInfo);
79
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010080 auto inputsTensorsProcess = ProcessInputs(logicalBinaryLayer,
81 delegateData,
82 tfLiteContext,
83 tfLiteNode);
84 if (inputsTensorsProcess == kTfLiteError)
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000085 {
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010086 return inputsTensorsProcess;
Matthew Sloyanc8eb9552020-11-26 10:54:22 +000087 }
88
89 // LogicalBinary operators support broadcasting
90 auto reshapeLayer = BroadcastTensor(inputTensorInfo0,
91 inputTensorInfo1,
92 logicalBinaryLayer,
93 tfLiteContext,
94 tfLiteNode,
95 delegateData);
96 if (!reshapeLayer)
97 {
98 return kTfLiteError;
99 }
100 return kTfLiteOk;
101}
102
103} // namespace armnnDelegate