blob: 9ce06a8d45b328541254b273c38a94b1f814ebde [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
Mike Kelly04d82292023-01-19 18:29:40 +00002// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan62483be2020-10-23 17:14:43 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
9
Mike Kelly04d82292023-01-19 18:29:40 +000010#include "armnnUtils/TensorUtils.hpp"
Finn Williams6f9f9902020-11-13 13:23:15 +000011#include <armnn/utility/IgnoreUnused.hpp>
Sadik Armagan6e36a642020-11-10 21:18:41 +000012
Sadik Armagan62483be2020-10-23 17:14:43 +010013#include <tensorflow/lite/builtin_ops.h>
14#include <tensorflow/lite/c/builtin_op_data.h>
15#include <tensorflow/lite/c/common.h>
16#include <tensorflow/lite/minimal_logging.h>
17
18namespace armnnDelegate
19{
20
21TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
22 TfLiteContext* tfLiteContext,
23 TfLiteNode* tfLiteNode,
24 int nodeIndex,
25 int32_t operatorCode)
26{
Sadik Armagan6e36a642020-11-10 21:18:41 +000027 auto numInputs = tfLiteNode->inputs->size;
28 if (numInputs < 2)
29 {
30 TF_LITE_MAYBE_KERNEL_LOG(
31 tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
32 2, numInputs, nodeIndex);
33 return kTfLiteError;
34 }
35 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Mike Kelly84d63782022-05-06 12:14:16 +010036 bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);
Sadik Armagan6e36a642020-11-10 21:18:41 +000037
38 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
39 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000040 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000041 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000042 return kTfLiteError;
43 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000044
Sadik Armagan6e36a642020-11-10 21:18:41 +000045 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000046 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000047 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000048 return kTfLiteError;
49 }
50
51 const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000052 if (!IsValid(tfLiteContext, tfLiteWeightsTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000053 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000054 return kTfLiteError;
55 }
56
Mike Kelly84d63782022-05-06 12:14:16 +010057 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Ryan OShea4c231de2023-01-17 15:19:20 +000058 const armnn::TensorInfo& weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010059 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan6e36a642020-11-10 21:18:41 +000060
Ryan OShea3ad2e142023-01-13 10:19:20 +000061 // Check that we support fused activation before we attempt to create a layer
62 auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams *>(tfLiteNode->builtin_data);
Ryan OShea475c7a82023-01-30 14:24:15 +000063 TfLiteFusedActivation activationType=kTfLiteActNone;
Ryan OShea3ad2e142023-01-13 10:19:20 +000064 if (tfLiteNodeParameters)
65 {
66 activationType = tfLiteNodeParameters->activation;
Ryan OShea3ad2e142023-01-13 10:19:20 +000067 TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
68 outputTensorInfo, activationType);
69 if(activationStatus != kTfLiteOk)
70 {
71 return kTfLiteError;
72 }
73 }
74
Sadik Armagan6e36a642020-11-10 21:18:41 +000075 // Fully Connected Layer accepts two dimensional weights input
76 int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
77 if (weightsDimension != 2)
78 {
79 TF_LITE_MAYBE_KERNEL_LOG(
80 tfLiteContext,
81 "TfLiteArmnnDelegate: Dimension #$d for Fully Connected weights is not supported by Armnn"
82 " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
83 return kTfLiteError;
84 }
85
86 armnn::TensorInfo biasTensorInfo;
87 if (biasEnabled)
88 {
89 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000090 if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000091 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000092 return kTfLiteError;
93 }
Sadik Armagan6e36a642020-11-10 21:18:41 +000094 biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
95 }
96 else
97 {
98 biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
99 }
100
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000101 armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000102 if (inputTensorInfo.GetNumDimensions() > 2)
103 {
104 // Calculate reshape to flatten to 2D [batch_size, input_size]
105 std::vector<unsigned int> reshapedDimensions(2);
106 reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
107 reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
108
109 if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
110 {
111 TF_LITE_MAYBE_KERNEL_LOG(
112 tfLiteContext,
113 "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
114 reshapedDimensions[1], operatorCode, nodeIndex);
115 return kTfLiteError;
116 }
117
118 reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
119 }
Mike Kelly04d82292023-01-19 18:29:40 +0000120 armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
121
122 if (outputTensorInfo.GetNumDimensions() > 2)
123 {
124 // Calculate reshape to flatten to 2D [batch_size, input_size]
125 std::vector<unsigned int> reshapedDimensions(2);
126 reshapedDimensions[1] = weightsTensorInfo.GetShape()[0];
127 reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1];
128
129 if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
130 {
131 TF_LITE_MAYBE_KERNEL_LOG(
132 tfLiteContext,
133 "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ",
134 reshapedDimensions[1], operatorCode, nodeIndex);
135 return kTfLiteError;
136 }
137 reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
138 }
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000139
Sadik Armagan6e36a642020-11-10 21:18:41 +0000140 armnn::FullyConnectedDescriptor descriptor;
141 descriptor.m_TransposeWeightMatrix = true;
142 descriptor.m_BiasEnabled = biasEnabled;
Ryan OShea4c231de2023-01-17 15:19:20 +0000143 descriptor.m_ConstantWeights = weightsTensorInfo.IsConstant();
Sadik Armagan6e36a642020-11-10 21:18:41 +0000144
145 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100146 armnn::BackendId setBackend;
Sadik Armagan6e36a642020-11-10 21:18:41 +0000147 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
148 {
Mike Kelly04d82292023-01-19 18:29:40 +0000149
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000150 FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED",
Sadik Armagan6e36a642020-11-10 21:18:41 +0000151 tfLiteContext,
152 IsFullyConnectedSupported,
153 delegateData.m_Backends,
154 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100155 setBackend,
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000156 reshapedTensorInfo,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000157 outputTensorInfo,
158 weightsTensorInfo,
159 biasTensorInfo,
160 descriptor);
161 };
162
163 if (!delegateData.m_Network)
164 {
Mike Kelly04d82292023-01-19 18:29:40 +0000165 validateFunc(reshapedOutputTensorInfo, isSupported);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000166 return isSupported ? kTfLiteOk : kTfLiteError;
167 }
168
Matthew Sloyan81beae32021-07-13 19:46:11 +0100169 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100170 layer->SetBackendId(setBackend);
Matthew Sloyan81beae32021-07-13 19:46:11 +0100171 ARMNN_ASSERT(layer != nullptr);
172
173 // Add a constant layer for weights and biases if inputs are constant.
Ryan OShea4c231de2023-01-17 15:19:20 +0000174 if (weightsTensorInfo.IsConstant())
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000175 {
176 auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
Ryan OShea4c231de2023-01-17 15:19:20 +0000177 weightsTensorInfo);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000178
Matthew Sloyan81beae32021-07-13 19:46:11 +0100179 armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor);
180
181 weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
182 weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo);
183 }
184
185 if (biasEnabled)
186 {
187 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
Ryan OShea4c231de2023-01-17 15:19:20 +0000188 if(biasTensorInfo.IsConstant())
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000189 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000190 auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
Ryan OShea4c231de2023-01-17 15:19:20 +0000191 biasTensorInfo);
Matthew Sloyan81beae32021-07-13 19:46:11 +0100192
193 armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
194 ARMNN_ASSERT(biasLayer != nullptr);
195
196 biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
197 biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000198 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000199 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000200
Ryan OShea4c231de2023-01-17 15:19:20 +0000201 // The data input can also be constant, so we must check that this is also allocated to an input slot
202 if(inputTensorInfo.IsConstant())
203 {
204 auto input =
205 CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[0]],
206 inputTensorInfo);
207
208 armnn::IConnectableLayer *inputLayer = delegateData.m_Network->AddConstantLayer(input);
209 inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
210 inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
211 }
212
Sadik Armagan6e36a642020-11-10 21:18:41 +0000213 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
214 outputSlot.SetTensorInfo(outputTensorInfo);
215
216 armnn::IConnectableLayer* reshapeLayer = nullptr;
217 if (inputTensorInfo.GetNumDimensions() > 2)
218 {
219 // Add reshape to flatten to 2D [batch_size, input_size]
Sadik Armagan6e36a642020-11-10 21:18:41 +0000220 armnn::ReshapeDescriptor reshapeDescriptor;
221 reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
222 reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
223 ARMNN_ASSERT(reshapeLayer != nullptr);
224
225 reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
226
227 // Connect
228 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
229 reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
Matthew Sloyan81beae32021-07-13 19:46:11 +0100230
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000231 if (!descriptor.m_ConstantWeights)
232 {
233 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
Matthew Sloyan81beae32021-07-13 19:46:11 +0100234 }
235
Ryan OShea4c231de2023-01-17 15:19:20 +0000236 if (biasEnabled && !biasTensorInfo.IsConstant())
Matthew Sloyan81beae32021-07-13 19:46:11 +0100237 {
238 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000239 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000240 delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
241 }
242
243 if (reshapeLayer == nullptr)
244 {
Ryan OShea4c231de2023-01-17 15:19:20 +0000245 if(Connect(layer, tfLiteNode, delegateData) != kTfLiteOk)
246 {
247 return kTfLiteError;
248 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000249 }
Ryan OShea3ad2e142023-01-13 10:19:20 +0000250
Mike Kelly04d82292023-01-19 18:29:40 +0000251 if (outputTensorInfo.GetNumDimensions() > 2)
252 {
253 layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo,
254 delegateData);
255 if (!layer)
256 {
257 TF_LITE_MAYBE_KERNEL_LOG(
258 tfLiteContext,
259 "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ",
260 operatorCode,
261 nodeIndex);
262 return kTfLiteError;
263 }
264 }
265
Sadik Armagan6e36a642020-11-10 21:18:41 +0000266 if (!tfLiteNodeParameters)
267 {
268 // No Activation
269 return kTfLiteOk;
270 }
Ryan OShea3ad2e142023-01-13 10:19:20 +0000271
272 // Check and Create Activation
Sadik Armagan6e36a642020-11-10 21:18:41 +0000273 return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100274}
275
Sadik Armagan6e36a642020-11-10 21:18:41 +0000276} // namespace armnnDelegate