blob: aef70e8e5bedda7ec4573154f203ea5080e1eda2 [file] [log] [blame]
James Conroy39825482021-05-27 17:44:50 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "DelegateUtils.hpp"
9
10#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus ValidatePreluOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 const armnn::TensorInfo& inputInfo,
21 const armnn::TensorInfo& alphaInfo,
22 const armnn::TensorInfo& outputInfo)
23{
24 bool isSupported = false;
25 auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
26 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000027 FORWARD_LAYER_SUPPORT_FUNC("PRELU",
James Conroy39825482021-05-27 17:44:50 +010028 tfLiteContext,
29 IsPreluSupported,
30 delegateData.m_Backends,
31 isSupported,
32 inputInfo,
33 alphaInfo,
34 outputInfo);
35 };
36
37 validateFunc(outputInfo, isSupported);
38 return isSupported ? kTfLiteOk : kTfLiteError;
39}
40
41TfLiteStatus VisitPreluOperator(DelegateData& delegateData,
42 TfLiteContext* tfLiteContext,
43 TfLiteNode* tfLiteNode,
44 int nodeIndex,
45 int32_t operatorCode)
46{
47 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
48 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
49
50 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
51
52 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
53 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
54 {
55 return kTfLiteError;
56 }
57
58 const TfLiteTensor& tfLiteAlphaTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
59 if (!IsValid(tfLiteContext, tfLiteAlphaTensor, operatorCode, nodeIndex))
60 {
61 return kTfLiteError;
62 }
63
64 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
65 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
66 {
67 return kTfLiteError;
68 }
69
70 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
71 const armnn::TensorInfo& alphaTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteAlphaTensor);
72 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
73
74 if (!delegateData.m_Network)
75 {
76 return ValidatePreluOperator(delegateData,
77 tfLiteContext,
78 inputTensorInfo,
79 alphaTensorInfo,
80 outputTensorInfo);
81 }
82
83 armnn::IConnectableLayer* preluLayer = delegateData.m_Network->AddPreluLayer();
84 ARMNN_ASSERT(preluLayer != nullptr);
85
86 bool isConstantAlpha = tflite::IsConstantTensor(&tfLiteAlphaTensor);
87
88 // Add constant layer for constant alpha
89 if (isConstantAlpha)
90 {
91 auto constAlphaTensor = armnn::ConstTensor(alphaTensorInfo, tfLiteAlphaTensor.data.data);
92
93 armnn::IConnectableLayer* constLayer = delegateData.m_Network->AddConstantLayer(constAlphaTensor);
94 ARMNN_ASSERT(constLayer != nullptr);
95
96 constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
97 constLayer->GetOutputSlot(0).Connect(preluLayer->GetInputSlot(1));
98 }
99
100 armnn::IOutputSlot& outputSlot = preluLayer->GetOutputSlot(0);
101 outputSlot.SetTensorInfo(outputTensorInfo);
102
103 // Connect
104 return Connect(preluLayer, tfLiteNode, delegateData);
105}
106
107} // namespace armnnDelegate