//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
11#include <tensorflow/lite/minimal_logging.h>
12
13namespace armnnDelegate
14{
15
Sadik Armagan4b227bb2021-01-22 10:53:38 +000016TfLiteStatus VisitL2NormalizationOperator(DelegateData& delegateData,
17 TfLiteContext* tfLiteContext,
18 TfLiteNode* tfLiteNode,
19 int nodeIndex,
20 int32_t operatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010021{
Sadik Armagan4b227bb2021-01-22 10:53:38 +000022 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
23 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000024
Sadik Armagan4b227bb2021-01-22 10:53:38 +000025 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
26 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
27 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
28 {
29 return kTfLiteError;
30 }
31
32 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
33 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
34 {
35 return kTfLiteError;
36 }
37
38 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
39 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
40
41 armnn::L2NormalizationDescriptor descriptor;
42 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
43
44 bool isSupported = false;
45 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
46 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000047 FORWARD_LAYER_SUPPORT_FUNC("L2_NORMALIZATION",
Sadik Armagan4b227bb2021-01-22 10:53:38 +000048 tfLiteContext,
49 IsL2NormalizationSupported,
50 delegateData.m_Backends,
51 isSupported,
52 inputTensorInfo,
53 outInfo,
54 descriptor);
55 };
56
57 if (!delegateData.m_Network)
58 {
59 validateFunc(outputTensorInfo, isSupported);
60 return isSupported ? kTfLiteOk : kTfLiteError;
61 }
62
63 // Add a L2Normalization layer
64 armnn::IConnectableLayer* layer = delegateData.m_Network->AddL2NormalizationLayer(descriptor);
65 ARMNN_ASSERT(layer != nullptr);
66
67 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
68 outputSlot.SetTensorInfo(outputTensorInfo);
69
70 // Connect
71 return Connect(layer, tfLiteNode, delegateData);
72}
73
74
75TfLiteStatus VisitLocalResponseNormalizationOperator(DelegateData& delegateData,
76 TfLiteContext* tfLiteContext,
77 TfLiteNode* tfLiteNode,
78 int nodeIndex,
79 int32_t normalizationOperatorCode)
80{
81 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
82 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
83
84 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
85 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
86 if (!IsValid(tfLiteContext, tfLiteInputTensor, normalizationOperatorCode, nodeIndex))
87 {
88 return kTfLiteError;
89 }
90
91 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
92 if (!IsValid(tfLiteContext, tfLiteOutputTensor, normalizationOperatorCode, nodeIndex))
93 {
94 return kTfLiteError;
95 }
96
97 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
98 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
99
100 armnn::NormalizationDescriptor descriptor;
101 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
102 descriptor.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across;
103 descriptor.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalBrightness;
104
105 auto* params = reinterpret_cast<TfLiteLocalResponseNormParams*>(tfLiteNode->builtin_data);
106 descriptor.m_NormSize = params->radius;
107 descriptor.m_K = params->bias;
108 descriptor.m_Alpha = params->alpha;
109 descriptor.m_Beta = params->beta;
110
111 // ArmNN expects normSize to be the full size of the normalization window
112 descriptor.m_NormSize = 1 + (2 * descriptor.m_NormSize);
113
114 bool isSupported = false;
115 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
116 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000117 FORWARD_LAYER_SUPPORT_FUNC("NORMALIZATION",
Sadik Armagan4b227bb2021-01-22 10:53:38 +0000118 tfLiteContext,
119 IsNormalizationSupported,
120 delegateData.m_Backends,
121 isSupported,
122 inputTensorInfo,
123 outInfo,
124 descriptor);
125 };
126
127 if (!delegateData.m_Network)
128 {
129 validateFunc(outputTensorInfo, isSupported);
130 return isSupported ? kTfLiteOk : kTfLiteError;
131 }
132
133 // Add a Normalization layer
134 armnn::IConnectableLayer* layer = delegateData.m_Network->AddNormalizationLayer(descriptor);
135 ARMNN_ASSERT(layer != nullptr);
136
137 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
138 outputSlot.SetTensorInfo(outputTensorInfo);
139
140 // Connect
141 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100142}
143
144} // namespace armnnDelegate