blob: d0db43ea7c03019e3f9c3c3147aedcbaac08ed64 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
11#include <tensorflow/lite/minimal_logging.h>
12
13namespace armnnDelegate
14{
15
Sadik Armagan4b227bb2021-01-22 10:53:38 +000016TfLiteStatus VisitL2NormalizationOperator(DelegateData& delegateData,
17 TfLiteContext* tfLiteContext,
18 TfLiteNode* tfLiteNode,
19 int nodeIndex,
20 int32_t operatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010021{
Sadik Armagan4b227bb2021-01-22 10:53:38 +000022 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
23 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000024
Sadik Armagan4b227bb2021-01-22 10:53:38 +000025 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
26 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
27 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
28 {
29 return kTfLiteError;
30 }
31
32 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
33 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
34 {
35 return kTfLiteError;
36 }
37
38 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010039 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan4b227bb2021-01-22 10:53:38 +000040
41 armnn::L2NormalizationDescriptor descriptor;
42 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
43
44 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010045 armnn::BackendId setBackend;
Sadik Armagan4b227bb2021-01-22 10:53:38 +000046 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
47 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000048 FORWARD_LAYER_SUPPORT_FUNC("L2_NORMALIZATION",
Sadik Armagan4b227bb2021-01-22 10:53:38 +000049 tfLiteContext,
50 IsL2NormalizationSupported,
51 delegateData.m_Backends,
52 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010053 setBackend,
Sadik Armagan4b227bb2021-01-22 10:53:38 +000054 inputTensorInfo,
55 outInfo,
56 descriptor);
57 };
58
59 if (!delegateData.m_Network)
60 {
61 validateFunc(outputTensorInfo, isSupported);
62 return isSupported ? kTfLiteOk : kTfLiteError;
63 }
64
65 // Add a L2Normalization layer
66 armnn::IConnectableLayer* layer = delegateData.m_Network->AddL2NormalizationLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +010067 layer->SetBackendId(setBackend);
Sadik Armagan4b227bb2021-01-22 10:53:38 +000068 ARMNN_ASSERT(layer != nullptr);
69
70 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
71 outputSlot.SetTensorInfo(outputTensorInfo);
72
73 // Connect
74 return Connect(layer, tfLiteNode, delegateData);
75}
76
77
78TfLiteStatus VisitLocalResponseNormalizationOperator(DelegateData& delegateData,
79 TfLiteContext* tfLiteContext,
80 TfLiteNode* tfLiteNode,
81 int nodeIndex,
82 int32_t normalizationOperatorCode)
83{
84 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
85 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
86
87 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
88 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
89 if (!IsValid(tfLiteContext, tfLiteInputTensor, normalizationOperatorCode, nodeIndex))
90 {
91 return kTfLiteError;
92 }
93
94 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
95 if (!IsValid(tfLiteContext, tfLiteOutputTensor, normalizationOperatorCode, nodeIndex))
96 {
97 return kTfLiteError;
98 }
99
100 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +0100101 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000102
103 armnn::NormalizationDescriptor descriptor;
104 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
105 descriptor.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across;
106 descriptor.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalBrightness;
107
108 auto* params = reinterpret_cast<TfLiteLocalResponseNormParams*>(tfLiteNode->builtin_data);
109 descriptor.m_NormSize = params->radius;
110 descriptor.m_K = params->bias;
111 descriptor.m_Alpha = params->alpha;
112 descriptor.m_Beta = params->beta;
113
114 // ArmNN expects normSize to be the full size of the normalization window
115 descriptor.m_NormSize = 1 + (2 * descriptor.m_NormSize);
116
117 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100118 armnn::BackendId setBackend;
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000119 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
120 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000121 FORWARD_LAYER_SUPPORT_FUNC("NORMALIZATION",
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000122 tfLiteContext,
123 IsNormalizationSupported,
124 delegateData.m_Backends,
125 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100126 setBackend,
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000127 inputTensorInfo,
128 outInfo,
129 descriptor);
130 };
131
132 if (!delegateData.m_Network)
133 {
134 validateFunc(outputTensorInfo, isSupported);
135 return isSupported ? kTfLiteOk : kTfLiteError;
136 }
137
138 // Add a Normalization layer
139 armnn::IConnectableLayer* layer = delegateData.m_Network->AddNormalizationLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100140 layer->SetBackendId(setBackend);
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000141 ARMNN_ASSERT(layer != nullptr);
142
143 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
144 outputSlot.SetTensorInfo(outputTensorInfo);
145
146 // Connect
147 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100148}
149
150} // namespace armnnDelegate