blob: ef2e524369dc2d7bd53aa3bdeb0b14c557671f00 [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
11#include <tensorflow/lite/minimal_logging.h>
12
13namespace armnnDelegate
14{
15
Sadik Armagan4b227bb2021-01-22 10:53:38 +000016TfLiteStatus VisitL2NormalizationOperator(DelegateData& delegateData,
17 TfLiteContext* tfLiteContext,
18 TfLiteNode* tfLiteNode,
19 int nodeIndex,
20 int32_t operatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010021{
Sadik Armagan4b227bb2021-01-22 10:53:38 +000022 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
23 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000024
Sadik Armagan4b227bb2021-01-22 10:53:38 +000025 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
26 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
27 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
28 {
29 return kTfLiteError;
30 }
31
32 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
33 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
34 {
35 return kTfLiteError;
36 }
37
38 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010039 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan4b227bb2021-01-22 10:53:38 +000040
41 armnn::L2NormalizationDescriptor descriptor;
42 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
43
44 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010045 armnn::BackendId setBackend;
Sadik Armagan4b227bb2021-01-22 10:53:38 +000046 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
47 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000048 FORWARD_LAYER_SUPPORT_FUNC("L2_NORMALIZATION",
Sadik Armagan4b227bb2021-01-22 10:53:38 +000049 tfLiteContext,
50 IsL2NormalizationSupported,
51 delegateData.m_Backends,
52 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010053 setBackend,
Sadik Armagan4b227bb2021-01-22 10:53:38 +000054 inputTensorInfo,
55 outInfo,
56 descriptor);
57 };
58
59 if (!delegateData.m_Network)
60 {
61 validateFunc(outputTensorInfo, isSupported);
62 return isSupported ? kTfLiteOk : kTfLiteError;
63 }
64
65 // Add a L2Normalization layer
66 armnn::IConnectableLayer* layer = delegateData.m_Network->AddL2NormalizationLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +010067 layer->SetBackendId(setBackend);
Sadik Armagan4b227bb2021-01-22 10:53:38 +000068 ARMNN_ASSERT(layer != nullptr);
69
70 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
71 outputSlot.SetTensorInfo(outputTensorInfo);
72
Ryan OShea4c231de2023-01-17 15:19:20 +000073 // try to connect the Constant Inputs if there are any
74 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
75 {
76 return kTfLiteError;
77 }
78
Sadik Armagan4b227bb2021-01-22 10:53:38 +000079 // Connect
80 return Connect(layer, tfLiteNode, delegateData);
81}
82
83
84TfLiteStatus VisitLocalResponseNormalizationOperator(DelegateData& delegateData,
85 TfLiteContext* tfLiteContext,
86 TfLiteNode* tfLiteNode,
87 int nodeIndex,
88 int32_t normalizationOperatorCode)
89{
90 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
91 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
92
93 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
94 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
95 if (!IsValid(tfLiteContext, tfLiteInputTensor, normalizationOperatorCode, nodeIndex))
96 {
97 return kTfLiteError;
98 }
99
100 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
101 if (!IsValid(tfLiteContext, tfLiteOutputTensor, normalizationOperatorCode, nodeIndex))
102 {
103 return kTfLiteError;
104 }
105
106 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +0100107 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000108
109 armnn::NormalizationDescriptor descriptor;
110 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
111 descriptor.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across;
112 descriptor.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalBrightness;
113
114 auto* params = reinterpret_cast<TfLiteLocalResponseNormParams*>(tfLiteNode->builtin_data);
115 descriptor.m_NormSize = params->radius;
116 descriptor.m_K = params->bias;
117 descriptor.m_Alpha = params->alpha;
118 descriptor.m_Beta = params->beta;
119
120 // ArmNN expects normSize to be the full size of the normalization window
121 descriptor.m_NormSize = 1 + (2 * descriptor.m_NormSize);
122
123 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100124 armnn::BackendId setBackend;
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000125 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
126 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000127 FORWARD_LAYER_SUPPORT_FUNC("NORMALIZATION",
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000128 tfLiteContext,
129 IsNormalizationSupported,
130 delegateData.m_Backends,
131 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100132 setBackend,
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000133 inputTensorInfo,
134 outInfo,
135 descriptor);
136 };
137
138 if (!delegateData.m_Network)
139 {
140 validateFunc(outputTensorInfo, isSupported);
141 return isSupported ? kTfLiteOk : kTfLiteError;
142 }
143
144 // Add a Normalization layer
145 armnn::IConnectableLayer* layer = delegateData.m_Network->AddNormalizationLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100146 layer->SetBackendId(setBackend);
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000147 ARMNN_ASSERT(layer != nullptr);
148
149 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
150 outputSlot.SetTensorInfo(outputTensorInfo);
151
Ryan OShea4c231de2023-01-17 15:19:20 +0000152 // try to connect the Constant Inputs if there are any
153 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
154 {
155 return kTfLiteError;
156 }
157
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000158 // Connect
159 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100160}
161
162} // namespace armnnDelegate