blob: 53251f7c5598192807b8da5713a87e8ff924e62d [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Sadik Armagan6e36a642020-11-10 21:18:41 +00008#include "DelegateUtils.hpp"
Finn Williams6f9f9902020-11-13 13:23:15 +00009#include <armnn/utility/IgnoreUnused.hpp>
Sadik Armagan6e36a642020-11-10 21:18:41 +000010
Sadik Armagan62483be2020-10-23 17:14:43 +010011#include <tensorflow/lite/builtin_ops.h>
12#include <tensorflow/lite/c/builtin_op_data.h>
13#include <tensorflow/lite/c/common.h>
14#include <tensorflow/lite/minimal_logging.h>
15
16namespace armnnDelegate
17{
18
19TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t operatorCode)
24{
Sadik Armagan6e36a642020-11-10 21:18:41 +000025 auto numInputs = tfLiteNode->inputs->size;
26 if (numInputs < 2)
27 {
28 TF_LITE_MAYBE_KERNEL_LOG(
29 tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
30 2, numInputs, nodeIndex);
31 return kTfLiteError;
32 }
33 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
34 bool biasEnabled = (numInputs == 3);
35
36 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
37 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
38 if(!IsValid(&tfLiteTensors[tfLiteNode->inputs->data[0]]))
39 {
40 TF_LITE_MAYBE_KERNEL_LOG(
41 tfLiteContext,
42 "TfLiteArmnnDelegate: Invalid input tensor in operator #%d node #%d: ",
43 operatorCode, nodeIndex);
44 return kTfLiteError;
45 }
46 if (IsDynamicTensor(tfLiteInputTensor))
47 {
48 TF_LITE_MAYBE_KERNEL_LOG(
49 tfLiteContext,
50 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in node #%d: ",
51 nodeIndex);
52 return kTfLiteError;
53 }
54 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
55 if(!IsValid(&tfLiteOutputTensor))
56 {
57 TF_LITE_MAYBE_KERNEL_LOG(
58 tfLiteContext,
59 "TfLiteArmnnDelegate: Invalid output tensor in operator #%d node #%d: ",
60 operatorCode, nodeIndex);
61 return kTfLiteError;
62 }
63 if (IsDynamicTensor(tfLiteOutputTensor))
64 {
65 TF_LITE_MAYBE_KERNEL_LOG(
66 tfLiteContext,
67 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in node #%d: ",
68 nodeIndex);
69 return kTfLiteError;
70 }
71
72 const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
73 if(!IsValid(&tfLiteWeightsTensor))
74 {
75 TF_LITE_MAYBE_KERNEL_LOG(
76 tfLiteContext,
77 "TfLiteArmnnDelegate: Invalid weights tensor in operator #%d node #%d: ",
78 operatorCode, nodeIndex);
79 return kTfLiteError;
80 }
81 if (IsDynamicTensor(tfLiteWeightsTensor))
82 {
83 TF_LITE_MAYBE_KERNEL_LOG(
84 tfLiteContext,
85 "TfLiteArmnnDelegate: Dynamic weight tensors are not supported in node #%d: ",
86 nodeIndex);
87 return kTfLiteError;
88 }
89
90 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
91 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
92
93 armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
94 // Fully Connected Layer accepts two dimensional weights input
95 int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
96 if (weightsDimension != 2)
97 {
98 TF_LITE_MAYBE_KERNEL_LOG(
99 tfLiteContext,
100 "TfLiteArmnnDelegate: Dimension #$d for Fully Connected weights is not supported by Armnn"
101 " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
102 return kTfLiteError;
103 }
104
105 armnn::TensorInfo biasTensorInfo;
106 if (biasEnabled)
107 {
108 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
109 if(!IsValid(&tfLiteBiasTensor))
110 {
111 TF_LITE_MAYBE_KERNEL_LOG(
112 tfLiteContext,
113 "TfLiteArmnnDelegate: Invalid bias tensor in operator #%d node #%d: ",
114 operatorCode, nodeIndex);
115 return kTfLiteError;
116 }
117 if (IsDynamicTensor(tfLiteBiasTensor))
118 {
119 TF_LITE_MAYBE_KERNEL_LOG(
120 tfLiteContext,
121 "TfLiteArmnnDelegate: Dynamic bias tensors are not supported in node #%d: ",
122 nodeIndex);
123 return kTfLiteError;
124 }
125 biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
126 }
127 else
128 {
129 biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
130 }
131
Narumol Prangnawarat66da7512020-11-20 14:50:54 +0000132 armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
133
134 if (inputTensorInfo.GetNumDimensions() > 2)
135 {
136 // Calculate reshape to flatten to 2D [batch_size, input_size]
137 std::vector<unsigned int> reshapedDimensions(2);
138 reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
139 reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
140
141 if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
142 {
143 TF_LITE_MAYBE_KERNEL_LOG(
144 tfLiteContext,
145 "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
146 reshapedDimensions[1], operatorCode, nodeIndex);
147 return kTfLiteError;
148 }
149
150 reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
151 }
152
Sadik Armagan6e36a642020-11-10 21:18:41 +0000153 armnn::FullyConnectedDescriptor descriptor;
154 descriptor.m_TransposeWeightMatrix = true;
155 descriptor.m_BiasEnabled = biasEnabled;
156
157 bool isSupported = false;
158 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
159 {
160 FORWARD_LAYER_SUPPORT_FUNC(__func__,
161 tfLiteContext,
162 IsFullyConnectedSupported,
163 delegateData.m_Backends,
164 isSupported,
Narumol Prangnawarat66da7512020-11-20 14:50:54 +0000165 reshapedTensorInfo,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000166 outputTensorInfo,
167 weightsTensorInfo,
168 biasTensorInfo,
169 descriptor);
170 };
171
172 if (!delegateData.m_Network)
173 {
174 validateFunc(outputTensorInfo, isSupported);
175 return isSupported ? kTfLiteOk : kTfLiteError;
176 }
177
178 auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
179 weightsTensorInfo,
180 armnn::Optional<armnn::PermutationVector&>());
181
182 armnn::IConnectableLayer* layer = nullptr;
183 if (biasEnabled)
184 {
185 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
186 auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
187 biasTensorInfo,
188 armnn::Optional<armnn::PermutationVector&>());
189 layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor,
Sadik Armagan4189cc52020-11-11 18:01:48 +0000190 weightsTensor,
191 armnn::Optional<armnn::ConstTensor>(biasTensor));
Sadik Armagan6e36a642020-11-10 21:18:41 +0000192 }
193 else
194 {
195 layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor,
Sadik Armagan4189cc52020-11-11 18:01:48 +0000196 weightsTensor,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000197 armnn::EmptyOptional());
198 }
199 ARMNN_ASSERT(layer != nullptr);
200
201 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
202 outputSlot.SetTensorInfo(outputTensorInfo);
203
204 armnn::IConnectableLayer* reshapeLayer = nullptr;
205 if (inputTensorInfo.GetNumDimensions() > 2)
206 {
207 // Add reshape to flatten to 2D [batch_size, input_size]
Sadik Armagan6e36a642020-11-10 21:18:41 +0000208 armnn::ReshapeDescriptor reshapeDescriptor;
209 reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
210 reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
211 ARMNN_ASSERT(reshapeLayer != nullptr);
212
213 reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
214
215 // Connect
216 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
217 reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
Sadik Armagan6e36a642020-11-10 21:18:41 +0000218 delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
219 }
220
221 if (reshapeLayer == nullptr)
222 {
223 Connect(layer, tfLiteNode, delegateData);
224 }
225
226 auto* tfLiteNodeParameters = reinterpret_cast<TfLiteAddParams*>(tfLiteNode->builtin_data);
227 if (!tfLiteNodeParameters)
228 {
229 // No Activation
230 return kTfLiteOk;
231 }
232
233 // Check Activation
234 TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
235 return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100236}
237
Sadik Armagan6e36a642020-11-10 21:18:41 +0000238} // namespace armnnDelegate