blob: 3282cab543d70b7537934e2e3ea8041cab7e0343 [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9#include <SharedFunctions.hpp>
10
11namespace armnnOpaqueDelegate
12{
13
14TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
15 TfLiteOpaqueContext* tfLiteContext,
16 TfLiteOpaqueNode* tfLiteNode,
17 int nodeIndex,
18 int32_t operatorCode)
19{
20 auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
21 if (numInputs < 2)
22 {
23 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
24 tfLiteContext,
25 "TfLiteArmnnOpaqueDelegate: Minimum number of inputs (%d != %d) in node #%d",
26 2, numInputs, nodeIndex);
27 return kTfLiteError;
28 }
29 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
30
31 // Gather input indices and use to get input tensor.
32 const int* inputTensors;
33 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
34 {
35 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
36 tfLiteContext,
37 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
38 nodeIndex);
39 return kTfLiteError;
40 }
41
42 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
43 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
44 {
45 return kTfLiteError;
46 }
47
48 const TfLiteOpaqueTensor* tfLiteWeightsTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
49 if (!IsValid(tfLiteContext, tfLiteWeightsTensor, operatorCode, nodeIndex))
50 {
51 return kTfLiteError;
52 }
53
54 // Gather output indices and use to get output tensors.
55 int numOutputs = 0;
56 const int* outputTensors;
57 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
58 {
59 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
60 tfLiteContext,
61 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
62 nodeIndex);
63 return kTfLiteError;
64 }
65
66 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
67 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
68 {
69 return kTfLiteError;
70 }
71
72 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
73 const armnn::TensorInfo& weightsTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteWeightsTensor);
74 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
75
76 // Check that we support fused activation before we attempt to create a layer
77 auto* tfLiteNodeParameters =
78 reinterpret_cast<TfLiteFullyConnectedParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
79 TfLiteFusedActivation activationType=kTfLiteActNone;
80 if (tfLiteNodeParameters)
81 {
82 activationType = tfLiteNodeParameters->activation;
83 TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
84 outputTensorInfo, activationType);
85 if(activationStatus != kTfLiteOk)
86 {
87 return kTfLiteError;
88 }
89 }
90
91 // Fully Connected Layer accepts two dimensional weights input
92 int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
93 if (weightsDimension != 2)
94 {
95 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
96 tfLiteContext,
97 "TfLiteArmnnOpaqueDelegate: Dimension #$d for Fully Connected weights is not supported by Armnn"
98 " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
99 return kTfLiteError;
100 }
101
102 armnn::TensorInfo biasTensorInfo;
103 const TfLiteOpaqueTensor* tfLiteBiasTensor = nullptr;
104
105 bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);
106 if (biasEnabled)
107 {
108 // Use input indices to get bias tensor.
109 tfLiteBiasTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[2]);
110 if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
111 {
112 return kTfLiteError;
113 }
114 biasTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteBiasTensor);
115 }
116 else
117 {
118 biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
119 }
120
121 armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
122 if (inputTensorInfo.GetNumDimensions() > 2)
123 {
124 // Calculate reshape to flatten to 2D [batch_size, input_size]
125 std::vector<unsigned int> reshapedDimensions(2);
126 reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
127 reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
128
129 if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
130 {
131 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
132 tfLiteContext,
133 "TfLiteArmnnOpaqueDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
134 reshapedDimensions[1], operatorCode, nodeIndex);
135 return kTfLiteError;
136 }
137
138 reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
139 }
140 armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor);
141
142 if (outputTensorInfo.GetNumDimensions() > 2)
143 {
144 // Calculate reshape to flatten to 2D [batch_size, input_size]
145 std::vector<unsigned int> reshapedDimensions(2);
146 reshapedDimensions[1] = weightsTensorInfo.GetShape()[0];
147 reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1];
148
149 if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
150 {
151 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
152 tfLiteContext,
153 "TfLiteArmnnOpaqueDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ",
154 reshapedDimensions[1], operatorCode, nodeIndex);
155 return kTfLiteError;
156 }
157 reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
158 }
159
160 armnn::FullyConnectedDescriptor descriptor;
161 descriptor.m_TransposeWeightMatrix = true;
162 descriptor.m_BiasEnabled = biasEnabled;
163 descriptor.m_ConstantWeights = weightsTensorInfo.IsConstant();
164
165 bool isSupported = false;
166 armnn::BackendId setBackend;
167 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
168 {
169
170 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("FULLY_CONNECTED",
171 tfLiteContext,
172 IsFullyConnectedSupported,
173 delegateData.m_Backends,
174 isSupported,
175 setBackend,
176 reshapedTensorInfo,
177 outputTensorInfo,
178 weightsTensorInfo,
179 biasTensorInfo,
180 descriptor);
181 };
182
183 if (!delegateData.m_Network)
184 {
185 validateFunc(reshapedOutputTensorInfo, isSupported);
186 return isSupported ? kTfLiteOk : kTfLiteError;
187 }
188
189 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor);
190 layer->SetBackendId(setBackend);
191 ARMNN_ASSERT(layer != nullptr);
192
193 // Add a constant layer for weights and biases if inputs are constant.
194 if (weightsTensorInfo.IsConstant())
195 {
196 auto weightsTensor = CreateConstTensor(tfLiteWeightsTensor, weightsTensorInfo);
197
198 armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor);
199
200 weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
201 weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo);
202 }
203
204 if (biasEnabled)
205 {
206 if(biasTensorInfo.IsConstant())
207 {
208 auto biasTensor = CreateConstTensor(tfLiteBiasTensor, biasTensorInfo);
209
210 armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
211 ARMNN_ASSERT(biasLayer != nullptr);
212
213 biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
214 biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
215 }
216 }
217
218 // The data input can also be constant, so we must check that this is also allocated to an input slot
219 if(inputTensorInfo.IsConstant())
220 {
221 auto input = CreateConstTensor(tfLiteInputTensor, inputTensorInfo);
222
223 armnn::IConnectableLayer* inputLayer = delegateData.m_Network->AddConstantLayer(input);
224 inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
225 inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
226 }
227
228 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
229 outputSlot.SetTensorInfo(outputTensorInfo);
230
231 armnn::IConnectableLayer* reshapeLayer = nullptr;
232 if (inputTensorInfo.GetNumDimensions() > 2)
233 {
234 // Add reshape to flatten to 2D [batch_size, input_size]
235 armnn::ReshapeDescriptor reshapeDescriptor;
236 reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
237 reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
238 ARMNN_ASSERT(reshapeLayer != nullptr);
239
240 reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
241
242 // Connect
243 delegateData.m_OutputSlotForNode[inputTensors[0]]->Connect(reshapeLayer->GetInputSlot(0));
244 reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
245
246 if (!descriptor.m_ConstantWeights)
247 {
248 delegateData.m_OutputSlotForNode[inputTensors[1]]->Connect(layer->GetInputSlot(1));
249 }
250
251 if (biasEnabled && !biasTensorInfo.IsConstant())
252 {
253 delegateData.m_OutputSlotForNode[inputTensors[2]]->Connect(layer->GetInputSlot(2));
254 }
255 delegateData.m_OutputSlotForNode[outputTensors[0]] = &outputSlot;
256 }
257
258 if (reshapeLayer == nullptr)
259 {
260 if(Connect(layer, tfLiteContext, tfLiteNode, delegateData) != kTfLiteOk)
261 {
262 return kTfLiteError;
263 }
264 }
265
266 if (outputTensorInfo.GetNumDimensions() > 2)
267 {
268 layer = AddReshapeLayer(tfLiteContext,
269 tfLiteNode,
270 layer,
271 reshapedOutputTensorInfo,
272 outputTensorInfo,
273 delegateData);
274 if (!layer)
275 {
276 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
277 tfLiteContext,
278 "TfLiteArmnnOpaqueDelegate: Failed to add reshape for FullyConnected #%d node #%d: ",
279 operatorCode,
280 nodeIndex);
281 return kTfLiteError;
282 }
283 }
284
285 if (!tfLiteNodeParameters)
286 {
287 // No Activation
288 return kTfLiteOk;
289 }
290
291 // Check and Create Activation
292 return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
293}
294
295} // namespace armnnOpaqueDelegate