blob: 1c9f06d8b8a250840fed97d69762b6fff17a762b [file] [log] [blame]
Francis Murtaghc4fb0dd2023-03-16 17:01:56 +00001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
Matthew Sloyan0bd4c622023-04-27 11:48:26 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12
13TfLiteStatus ValidatePreluOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 const armnn::TensorInfo& inputInfo,
16 const armnn::TensorInfo& alphaInfo,
17 const armnn::TensorInfo& outputInfo)
18{
19 bool isSupported = false;
20 auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
21 {
22 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("PRELU",
23 tfLiteContext,
24 IsPreluSupported,
25 delegateData.m_Backends,
26 isSupported,
27 armnn::BackendId(),
28 inputInfo,
29 alphaInfo,
30 outputInfo);
31 };
32
33 validateFunc(outputInfo, isSupported);
34 return isSupported ? kTfLiteOk : kTfLiteError;
35}
36
37TfLiteStatus VisitPreluOperator(DelegateData& delegateData,
38 TfLiteOpaqueContext* tfLiteContext,
39 TfLiteOpaqueNode* tfLiteNode,
40 int nodeIndex,
41 int32_t operatorCode)
42{
43 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
44 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
45
46 // Gather input indices and use to get input tensor.
47 int numInputs = 0;
48 const int* inputTensors;
49 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
50 {
51 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
52 tfLiteContext,
53 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
54 nodeIndex);
55 return kTfLiteError;
56 }
57
58 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
59 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
60 {
61 return kTfLiteError;
62 }
63
64 const TfLiteOpaqueTensor* tfLiteAlphaTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
65 if (!IsValid(tfLiteContext, tfLiteAlphaTensor, operatorCode, nodeIndex))
66 {
67 return kTfLiteError;
68 }
69
70 // Gather output indices and use to get output tensors.
71 int numOutputs = 0;
72 const int* outputTensors;
73 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
74 {
75 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
76 tfLiteContext,
77 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
78 nodeIndex);
79 return kTfLiteError;
80 }
81
82 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
83 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
84 {
85 return kTfLiteError;
86 }
87
88 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
89 const armnn::TensorInfo& alphaTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteAlphaTensor);
90 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
91
92 if (!delegateData.m_Network)
93 {
94 return ValidatePreluOperator(delegateData,
95 tfLiteContext,
96 inputTensorInfo,
97 alphaTensorInfo,
98 outputTensorInfo);
99 }
100
Mike Kellya2806502023-08-03 10:42:11 +0100101 auto layerName = GetName(armnn::LayerType::Prelu, nodeIndex);
102 armnn::IConnectableLayer* preluLayer = delegateData.m_Network->AddPreluLayer(layerName.c_str());
Matthew Sloyan0bd4c622023-04-27 11:48:26 +0100103 ARMNN_ASSERT(preluLayer != nullptr);
104
105 bool isConstantAlpha = IsConstantTensor(tfLiteAlphaTensor);
106
107 // Add constant layer for constant alpha
108 if (isConstantAlpha)
109 {
110 auto constAlphaTensor = armnn::ConstTensor(alphaTensorInfo, TfLiteOpaqueTensorData(tfLiteAlphaTensor));
111
Mike Kellya2806502023-08-03 10:42:11 +0100112 auto alphaName = GetName(armnn::LayerType::Constant, nodeIndex, "Alpha");
113 armnn::IConnectableLayer* constLayer = delegateData.m_Network->AddConstantLayer(constAlphaTensor,
114 alphaName.c_str());
Matthew Sloyan0bd4c622023-04-27 11:48:26 +0100115 ARMNN_ASSERT(constLayer != nullptr);
116
117 constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
118 constLayer->GetOutputSlot(0).Connect(preluLayer->GetInputSlot(1));
119 }
120
121 armnn::IOutputSlot& outputSlot = preluLayer->GetOutputSlot(0);
122 outputSlot.SetTensorInfo(outputTensorInfo);
123
124 // Connect
125 return Connect(preluLayer, tfLiteContext, tfLiteNode, delegateData);
126}
127
128} // namespace armnnOpaqueDelegate