blob: 193e3f03768d76e41157a51d89dd8ff14c63ec09 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
James Conroy39825482021-05-27 17:44:50 +01009
10#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus ValidatePreluOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 const armnn::TensorInfo& inputInfo,
21 const armnn::TensorInfo& alphaInfo,
22 const armnn::TensorInfo& outputInfo)
23{
24 bool isSupported = false;
25 auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
26 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000027 FORWARD_LAYER_SUPPORT_FUNC("PRELU",
James Conroy39825482021-05-27 17:44:50 +010028 tfLiteContext,
29 IsPreluSupported,
30 delegateData.m_Backends,
31 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010032 armnn::BackendId(),
James Conroy39825482021-05-27 17:44:50 +010033 inputInfo,
34 alphaInfo,
35 outputInfo);
36 };
37
38 validateFunc(outputInfo, isSupported);
39 return isSupported ? kTfLiteOk : kTfLiteError;
40}
41
42TfLiteStatus VisitPreluOperator(DelegateData& delegateData,
43 TfLiteContext* tfLiteContext,
44 TfLiteNode* tfLiteNode,
45 int nodeIndex,
46 int32_t operatorCode)
47{
48 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
49 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
50
51 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
52
53 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
54 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
55 {
56 return kTfLiteError;
57 }
58
59 const TfLiteTensor& tfLiteAlphaTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
60 if (!IsValid(tfLiteContext, tfLiteAlphaTensor, operatorCode, nodeIndex))
61 {
62 return kTfLiteError;
63 }
64
65 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
66 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
67 {
68 return kTfLiteError;
69 }
70
71 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
72 const armnn::TensorInfo& alphaTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteAlphaTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010073 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
James Conroy39825482021-05-27 17:44:50 +010074
75 if (!delegateData.m_Network)
76 {
77 return ValidatePreluOperator(delegateData,
78 tfLiteContext,
79 inputTensorInfo,
80 alphaTensorInfo,
81 outputTensorInfo);
82 }
83
Mike Kelly07169c82023-08-02 13:23:09 +010084 auto layerName = GetLayerName(armnn::LayerType::Prelu, nodeIndex);
85 armnn::IConnectableLayer* preluLayer = delegateData.m_Network->AddPreluLayer(layerName.c_str());
James Conroy39825482021-05-27 17:44:50 +010086 ARMNN_ASSERT(preluLayer != nullptr);
87
88 bool isConstantAlpha = tflite::IsConstantTensor(&tfLiteAlphaTensor);
89
90 // Add constant layer for constant alpha
91 if (isConstantAlpha)
92 {
93 auto constAlphaTensor = armnn::ConstTensor(alphaTensorInfo, tfLiteAlphaTensor.data.data);
94
Mike Kelly07169c82023-08-02 13:23:09 +010095 auto alphaName = GetLayerName(armnn::LayerType::Constant, nodeIndex, "Alpha");
96 armnn::IConnectableLayer* constLayer = delegateData.m_Network->AddConstantLayer(constAlphaTensor,
97 alphaName.c_str());
James Conroy39825482021-05-27 17:44:50 +010098 ARMNN_ASSERT(constLayer != nullptr);
99
100 constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
101 constLayer->GetOutputSlot(0).Connect(preluLayer->GetInputSlot(1));
102 }
103
104 armnn::IOutputSlot& outputSlot = preluLayer->GetOutputSlot(0);
105 outputSlot.SetTensorInfo(outputTensorInfo);
106
107 // Connect
108 return Connect(preluLayer, tfLiteNode, delegateData);
109}
110
111} // namespace armnnDelegate