blob: 48bf06f94a49ac02e277d1efefc03b5f61e888aa [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Sadik Armagan6e36a642020-11-10 21:18:41 +00008#include "DelegateUtils.hpp"
9
Sadik Armagan62483be2020-10-23 17:14:43 +010010#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
22 int32_t operatorCode)
23{
Sadik Armagan6e36a642020-11-10 21:18:41 +000024 auto numInputs = tfLiteNode->inputs->size;
25 if (numInputs < 2)
26 {
27 TF_LITE_MAYBE_KERNEL_LOG(
28 tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
29 2, numInputs, nodeIndex);
30 return kTfLiteError;
31 }
32 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
33 bool biasEnabled = (numInputs == 3);
34
35 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
36 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
37 if(!IsValid(&tfLiteTensors[tfLiteNode->inputs->data[0]]))
38 {
39 TF_LITE_MAYBE_KERNEL_LOG(
40 tfLiteContext,
41 "TfLiteArmnnDelegate: Invalid input tensor in operator #%d node #%d: ",
42 operatorCode, nodeIndex);
43 return kTfLiteError;
44 }
45 if (IsDynamicTensor(tfLiteInputTensor))
46 {
47 TF_LITE_MAYBE_KERNEL_LOG(
48 tfLiteContext,
49 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in node #%d: ",
50 nodeIndex);
51 return kTfLiteError;
52 }
53 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
54 if(!IsValid(&tfLiteOutputTensor))
55 {
56 TF_LITE_MAYBE_KERNEL_LOG(
57 tfLiteContext,
58 "TfLiteArmnnDelegate: Invalid output tensor in operator #%d node #%d: ",
59 operatorCode, nodeIndex);
60 return kTfLiteError;
61 }
62 if (IsDynamicTensor(tfLiteOutputTensor))
63 {
64 TF_LITE_MAYBE_KERNEL_LOG(
65 tfLiteContext,
66 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in node #%d: ",
67 nodeIndex);
68 return kTfLiteError;
69 }
70
71 const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
72 if(!IsValid(&tfLiteWeightsTensor))
73 {
74 TF_LITE_MAYBE_KERNEL_LOG(
75 tfLiteContext,
76 "TfLiteArmnnDelegate: Invalid weights tensor in operator #%d node #%d: ",
77 operatorCode, nodeIndex);
78 return kTfLiteError;
79 }
80 if (IsDynamicTensor(tfLiteWeightsTensor))
81 {
82 TF_LITE_MAYBE_KERNEL_LOG(
83 tfLiteContext,
84 "TfLiteArmnnDelegate: Dynamic weight tensors are not supported in node #%d: ",
85 nodeIndex);
86 return kTfLiteError;
87 }
88
89 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
90 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
91
92 armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
93 // Fully Connected Layer accepts two dimensional weights input
94 int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
95 if (weightsDimension != 2)
96 {
97 TF_LITE_MAYBE_KERNEL_LOG(
98 tfLiteContext,
99 "TfLiteArmnnDelegate: Dimension #$d for Fully Connected weights is not supported by Armnn"
100 " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
101 return kTfLiteError;
102 }
103
104 armnn::TensorInfo biasTensorInfo;
105 if (biasEnabled)
106 {
107 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
108 if(!IsValid(&tfLiteBiasTensor))
109 {
110 TF_LITE_MAYBE_KERNEL_LOG(
111 tfLiteContext,
112 "TfLiteArmnnDelegate: Invalid bias tensor in operator #%d node #%d: ",
113 operatorCode, nodeIndex);
114 return kTfLiteError;
115 }
116 if (IsDynamicTensor(tfLiteBiasTensor))
117 {
118 TF_LITE_MAYBE_KERNEL_LOG(
119 tfLiteContext,
120 "TfLiteArmnnDelegate: Dynamic bias tensors are not supported in node #%d: ",
121 nodeIndex);
122 return kTfLiteError;
123 }
124 biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
125 }
126 else
127 {
128 biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
129 }
130
131 armnn::FullyConnectedDescriptor descriptor;
132 descriptor.m_TransposeWeightMatrix = true;
133 descriptor.m_BiasEnabled = biasEnabled;
134
135 bool isSupported = false;
136 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
137 {
138 FORWARD_LAYER_SUPPORT_FUNC(__func__,
139 tfLiteContext,
140 IsFullyConnectedSupported,
141 delegateData.m_Backends,
142 isSupported,
143 inputTensorInfo,
144 outputTensorInfo,
145 weightsTensorInfo,
146 biasTensorInfo,
147 descriptor);
148 };
149
150 if (!delegateData.m_Network)
151 {
152 validateFunc(outputTensorInfo, isSupported);
153 return isSupported ? kTfLiteOk : kTfLiteError;
154 }
155
156 auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
157 weightsTensorInfo,
158 armnn::Optional<armnn::PermutationVector&>());
159
160 armnn::IConnectableLayer* layer = nullptr;
161 if (biasEnabled)
162 {
163 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
164 auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
165 biasTensorInfo,
166 armnn::Optional<armnn::PermutationVector&>());
167 layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor,
Sadik Armagan4189cc52020-11-11 18:01:48 +0000168 weightsTensor,
169 armnn::Optional<armnn::ConstTensor>(biasTensor));
Sadik Armagan6e36a642020-11-10 21:18:41 +0000170 }
171 else
172 {
173 layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor,
Sadik Armagan4189cc52020-11-11 18:01:48 +0000174 weightsTensor,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000175 armnn::EmptyOptional());
176 }
177 ARMNN_ASSERT(layer != nullptr);
178
179 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
180 outputSlot.SetTensorInfo(outputTensorInfo);
181
182 armnn::IConnectableLayer* reshapeLayer = nullptr;
183 if (inputTensorInfo.GetNumDimensions() > 2)
184 {
185 // Add reshape to flatten to 2D [batch_size, input_size]
186 std::vector<unsigned int> reshapedDimensions(2);
187 reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
188 reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
189
190 if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
191 {
192 TF_LITE_MAYBE_KERNEL_LOG(
193 tfLiteContext,
194 "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
195 reshapedDimensions[1], operatorCode, nodeIndex);
196 return kTfLiteError;
197 }
198
199 armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
200 reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
201
202 armnn::ReshapeDescriptor reshapeDescriptor;
203 reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
204 reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
205 ARMNN_ASSERT(reshapeLayer != nullptr);
206
207 reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
208
209 // Connect
210 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
211 reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
212 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
213 delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
214 }
215
216 if (reshapeLayer == nullptr)
217 {
218 Connect(layer, tfLiteNode, delegateData);
219 }
220
221 auto* tfLiteNodeParameters = reinterpret_cast<TfLiteAddParams*>(tfLiteNode->builtin_data);
222 if (!tfLiteNodeParameters)
223 {
224 // No Activation
225 return kTfLiteOk;
226 }
227
228 // Check Activation
229 TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
230 return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100231}
232
Sadik Armagan6e36a642020-11-10 21:18:41 +0000233} // namespace armnnDelegate