blob: 1c55c2e9b5a21fb424f5fe595685f31c877fb392 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan62483be2020-10-23 17:14:43 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
Finn Williams6f9f9902020-11-13 13:23:15 +00009
Sadik Armagan62483be2020-10-23 17:14:43 +010010#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
David Monahan0cf84422020-11-16 15:53:03 +000018TfLiteStatus ValidateActivationOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 const armnn::TensorInfo& inputInfo,
21 const armnn::TensorInfo& outputInfo,
22 armnn::ActivationDescriptor& activationDesc)
23{
24 bool isSupported = false;
Keith Davis892fafe2020-11-26 17:40:35 +000025 auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
David Monahan0cf84422020-11-16 15:53:03 +000026 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000027 FORWARD_LAYER_SUPPORT_FUNC("ACTIVATION",
David Monahan0cf84422020-11-16 15:53:03 +000028 tfLiteContext,
29 IsActivationSupported,
30 delegateData.m_Backends,
31 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010032 armnn::BackendId(),
David Monahan0cf84422020-11-16 15:53:03 +000033 inputInfo,
34 outputInfo,
35 activationDesc);
36 };
37
38 validateFunc(outputInfo, isSupported);
39 return isSupported ? kTfLiteOk : kTfLiteError;
40}
41
Sadik Armagan62483be2020-10-23 17:14:43 +010042TfLiteStatus VisitActivationOperator(DelegateData& delegateData,
43 TfLiteContext* tfLiteContext,
44 TfLiteNode* tfLiteNode,
45 int nodeIndex,
David Monahan0cf84422020-11-16 15:53:03 +000046 int32_t operatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010047{
David Monahan0cf84422020-11-16 15:53:03 +000048 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
49 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000050
David Monahan0cf84422020-11-16 15:53:03 +000051 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
52 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
Matthew Sloyan7515d072020-12-16 12:50:01 +000053 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
David Monahan0cf84422020-11-16 15:53:03 +000054 {
David Monahan0cf84422020-11-16 15:53:03 +000055 return kTfLiteError;
56 }
Matthew Sloyan7515d072020-12-16 12:50:01 +000057
David Monahan0cf84422020-11-16 15:53:03 +000058 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Matthew Sloyan7515d072020-12-16 12:50:01 +000059 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
David Monahan0cf84422020-11-16 15:53:03 +000060 {
David Monahan0cf84422020-11-16 15:53:03 +000061 return kTfLiteError;
62 }
Finn Williams6f9f9902020-11-13 13:23:15 +000063
David Monahan0cf84422020-11-16 15:53:03 +000064 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010065 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
David Monahan0cf84422020-11-16 15:53:03 +000066
67 armnn::ActivationDescriptor activationDesc;
68 switch(operatorCode)
69 {
70 case kTfLiteBuiltinRelu:
71 {
72 activationDesc.m_Function = armnn::ActivationFunction::ReLu;
73 break;
74 }
75 case kTfLiteBuiltinRelu6:
76 {
77 activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
78 activationDesc.m_A = 6.0f;
79 break;
80 }
81 case kTfLiteBuiltinLogistic:
82 {
83 activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
84 break;
85 }
86 case kTfLiteBuiltinTanh:
87 {
88 activationDesc.m_Function = armnn::ActivationFunction::TanH;
89 activationDesc.m_A = 1.0f;
90 activationDesc.m_B = 1.0f;
91 break;
92 }
Matthew Sloyan7515d072020-12-16 12:50:01 +000093 case kTfLiteBuiltinElu:
94 {
95 activationDesc.m_Function = armnn::ActivationFunction::Elu;
96 activationDesc.m_A = 1.0f;
97 break;
98 }
99 case kTfLiteBuiltinHardSwish:
100 {
101 activationDesc.m_Function = armnn::ActivationFunction::HardSwish;
102 break;
103 }
Tianle Chengae931732023-07-28 11:53:04 +0100104 case kTfLiteBuiltinLeakyRelu:
105 {
106 // Get the alpha param from builtin data
107 auto* leakyReluParameters = reinterpret_cast<TfLiteLeakyReluParams*>(tfLiteNode->builtin_data);
108 activationDesc.m_Function = armnn::ActivationFunction::LeakyReLu;
109 activationDesc.m_A = leakyReluParameters->alpha;
110 break;
111 }
Teresa Charlin077cddb2023-09-15 15:19:21 +0100112 case kTfLiteBuiltinGelu:
113 {
114 activationDesc.m_Function = armnn::ActivationFunction::Gelu;
115 break;
116 }
David Monahan0cf84422020-11-16 15:53:03 +0000117 default:
118 {
119 return kTfLiteError;
120 }
121 }
122 if (!delegateData.m_Network)
123 {
124 return ValidateActivationOperator(delegateData,
125 tfLiteContext,
126 inputTensorInfo,
127 outputTensorInfo,
128 activationDesc);
129 }
Mike Kelly07169c82023-08-02 13:23:09 +0100130 auto layerName = GetLayerName(activationDesc.m_Function, nodeIndex);
131 armnn::IConnectableLayer* activationLayer = delegateData.m_Network->AddActivationLayer(activationDesc,
132 layerName.c_str());
David Monahan0cf84422020-11-16 15:53:03 +0000133 ARMNN_ASSERT(activationLayer != nullptr);
134
135 armnn::IOutputSlot& outputSlot = activationLayer->GetOutputSlot(0);
136 outputSlot.SetTensorInfo(outputTensorInfo);
137
Ryan OShea4c231de2023-01-17 15:19:20 +0000138 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100139 if (ProcessInputs(activationLayer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +0000140 {
141 return kTfLiteError;
142 }
143
David Monahan0cf84422020-11-16 15:53:03 +0000144 // Connect
145 return Connect(activationLayer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100146}
147
148} // namespace armnnDelegate