blob: 06e74ed6355b51eea8192260b9387d85560f4a81 [file] [log] [blame]
James Conroy39825482021-05-27 17:44:50 +01001//
Sadik Armagan90a119b2022-08-05 16:12:49 +01002// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
James Conroy39825482021-05-27 17:44:50 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "DelegateUtils.hpp"
9
10#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus ValidatePreluOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 const armnn::TensorInfo& inputInfo,
21 const armnn::TensorInfo& alphaInfo,
22 const armnn::TensorInfo& outputInfo)
23{
24 bool isSupported = false;
25 auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
26 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000027 FORWARD_LAYER_SUPPORT_FUNC("PRELU",
James Conroy39825482021-05-27 17:44:50 +010028 tfLiteContext,
29 IsPreluSupported,
30 delegateData.m_Backends,
31 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010032 armnn::BackendId(),
James Conroy39825482021-05-27 17:44:50 +010033 inputInfo,
34 alphaInfo,
35 outputInfo);
36 };
37
38 validateFunc(outputInfo, isSupported);
39 return isSupported ? kTfLiteOk : kTfLiteError;
40}
41
42TfLiteStatus VisitPreluOperator(DelegateData& delegateData,
43 TfLiteContext* tfLiteContext,
44 TfLiteNode* tfLiteNode,
45 int nodeIndex,
46 int32_t operatorCode)
47{
48 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
49 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
50
51 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
52
53 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
54 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
55 {
56 return kTfLiteError;
57 }
58
59 const TfLiteTensor& tfLiteAlphaTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
60 if (!IsValid(tfLiteContext, tfLiteAlphaTensor, operatorCode, nodeIndex))
61 {
62 return kTfLiteError;
63 }
64
65 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
66 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
67 {
68 return kTfLiteError;
69 }
70
71 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
72 const armnn::TensorInfo& alphaTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteAlphaTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010073 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
James Conroy39825482021-05-27 17:44:50 +010074
75 if (!delegateData.m_Network)
76 {
77 return ValidatePreluOperator(delegateData,
78 tfLiteContext,
79 inputTensorInfo,
80 alphaTensorInfo,
81 outputTensorInfo);
82 }
83
84 armnn::IConnectableLayer* preluLayer = delegateData.m_Network->AddPreluLayer();
85 ARMNN_ASSERT(preluLayer != nullptr);
86
87 bool isConstantAlpha = tflite::IsConstantTensor(&tfLiteAlphaTensor);
88
89 // Add constant layer for constant alpha
90 if (isConstantAlpha)
91 {
92 auto constAlphaTensor = armnn::ConstTensor(alphaTensorInfo, tfLiteAlphaTensor.data.data);
93
94 armnn::IConnectableLayer* constLayer = delegateData.m_Network->AddConstantLayer(constAlphaTensor);
95 ARMNN_ASSERT(constLayer != nullptr);
96
97 constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
98 constLayer->GetOutputSlot(0).Connect(preluLayer->GetInputSlot(1));
99 }
100
101 armnn::IOutputSlot& outputSlot = preluLayer->GetOutputSlot(0);
102 outputSlot.SetTensorInfo(outputTensorInfo);
103
104 // Connect
105 return Connect(preluLayer, tfLiteNode, delegateData);
106}
107
108} // namespace armnnDelegate