//
// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <ClassicDelegateUtils.hpp>

#include "armnnUtils/TensorUtils.hpp"
#include <armnn/utility/IgnoreUnused.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnDelegate
{

TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
                                         TfLiteContext* tfLiteContext,
                                         TfLiteNode* tfLiteNode,
                                         int nodeIndex,
                                         int32_t operatorCode)
{
    auto numInputs = tfLiteNode->inputs->size;
    if (numInputs < 2)
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
            2, numInputs, nodeIndex);
        return kTfLiteError;
    }
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
    bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);
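
    // TFLite FULLY_CONNECTED operands: inputs[0] = data, inputs[1] = weights, inputs[2] = optional bias,
    // plus a single output tensor. Each operand is validated before any ArmNN layer is created.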
    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
    if (!IsValid(tfLiteContext, tfLiteWeightsTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo   = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
    const armnn::TensorInfo& outputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);

    // Check that we support fused activation before we attempt to create a layer
    auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
    TfLiteFusedActivation activationType = kTfLiteActNone;
    if (tfLiteNodeParameters)
    {
        activationType = tfLiteNodeParameters->activation;
        TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
                                                                        outputTensorInfo, activationType);
        if (activationStatus != kTfLiteOk)
        {
            return kTfLiteError;
        }
    }

    // The Fully Connected layer accepts a two-dimensional weights tensor
    int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
    if (weightsDimension != 2)
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dimension #%d for Fully Connected weights is not supported by Armnn"
            " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
        return kTfLiteError;
    }

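    // A bias TensorInfo is always needed for the backend support check below; when the operator carries
    // no bias, a one-element placeholder with the input's data type is used instead.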
    armnn::TensorInfo biasTensorInfo;
    if (biasEnabled)
    {
        const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
        if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
        {
            return kTfLiteError;
        }
        biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
    }
    else
    {
        biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
    }

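    // ArmNN's FullyConnected expects a 2D input of shape [batch_size, input_size]. If the TFLite input
    // has a higher rank, a flattened TensorInfo is computed here for validation and a Reshape layer is
    // inserted in front of the FullyConnected layer further down.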
    armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    if (inputTensorInfo.GetNumDimensions() > 2)
    {
        // Calculate reshape to flatten to 2D [batch_size, input_size]
        std::vector<unsigned int> reshapedDimensions(2);
        reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
        reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];

        if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
        {
            TF_LITE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
                reshapedDimensions[1], operatorCode, nodeIndex);
            return kTfLiteError;
        }

        reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
    }
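
    // The TFLite output can likewise have more than two dimensions. A 2D view is computed here for the
    // backend support check, and the result is reshaped back to the original output shape after the
    // FullyConnected layer (see AddReshapeLayer below).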
    armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);

    if (outputTensorInfo.GetNumDimensions() > 2)
    {
        // Calculate reshape to flatten the output to 2D [batch_size, num_units]
        std::vector<unsigned int> reshapedDimensions(2);
        reshapedDimensions[1] = weightsTensorInfo.GetShape()[0];
        reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1];

        if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
        {
            TF_LITE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ",
                reshapedDimensions[1], operatorCode, nodeIndex);
            return kTfLiteError;
        }
        reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
    }

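    // TFLite stores fully connected weights as [num_units, input_size], so the weight matrix is marked
    // as transposed for ArmNN; weights are treated as constant only if the TFLite tensor is constant.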
    armnn::FullyConnectedDescriptor descriptor;
    descriptor.m_TransposeWeightMatrix = true;
    descriptor.m_BiasEnabled           = biasEnabled;
    descriptor.m_ConstantWeights       = weightsTensorInfo.IsConstant();

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED",
                                   tfLiteContext,
                                   IsFullyConnectedSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   setBackend,
                                   reshapedTensorInfo,
                                   outputTensorInfo,
                                   weightsTensorInfo,
                                   biasTensorInfo,
                                   descriptor);
    };

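    // With no network to build into, this call is only a support query: run the backend check on the
    // flattened shapes and report the result.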
    if (!delegateData.m_Network)
    {
        validateFunc(reshapedOutputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    auto layerName = GetLayerName(armnn::LayerType::FullyConnected, nodeIndex);
    armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor, layerName.c_str());
    ARMNN_ASSERT(layer != nullptr);
    layer->SetBackendId(setBackend);

    // Add a constant layer for weights and biases if inputs are constant.
    if (weightsTensorInfo.IsConstant())
    {
        auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor, weightsTensorInfo);

        auto weightsName = GetLayerName(armnn::LayerType::Constant, nodeIndex, "Weights");
        armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor,
                                                                                          weightsName.c_str());

        weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
        weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo);
    }

    if (biasEnabled)
    {
        const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
        if (biasTensorInfo.IsConstant())
        {
            auto biasTensor = CreateConstTensor(&tfLiteBiasTensor, biasTensorInfo);

            auto biasName = GetLayerName(armnn::LayerType::FullyConnected, nodeIndex, "Bias");
            armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor,
                                                                                           biasName.c_str());
            ARMNN_ASSERT(biasLayer != nullptr);

            biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
            biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
        }
    }

    // The data input can also be constant, so we must check that it is also allocated to an input slot.
    if (inputTensorInfo.IsConstant())
    {
        auto input = CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[0]], inputTensorInfo);

        auto constantName = GetLayerName(armnn::LayerType::Constant, nodeIndex, "Input");
        armnn::IConnectableLayer* inputLayer = delegateData.m_Network->AddConstantLayer(input, constantName.c_str());
        inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
        inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
    }

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    armnn::IConnectableLayer* reshapeLayer = nullptr;
    if (inputTensorInfo.GetNumDimensions() > 2)
    {
        // Add reshape to flatten to 2D [batch_size, input_size]
        armnn::ReshapeDescriptor reshapeDescriptor;
        reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();

        auto reshapeName = GetLayerName(armnn::LayerType::Reshape, nodeIndex, "Input");
        reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor, reshapeName.c_str());
        ARMNN_ASSERT(reshapeLayer != nullptr);

        reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);

        // Connect
        delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
        reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));

        if (!descriptor.m_ConstantWeights)
        {
            delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
        }

        if (biasEnabled && !biasTensorInfo.IsConstant())
        {
            delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
        }
        delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
    }
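
    // No input flattening was needed: let the shared Connect helper wire the remaining (non-constant)
    // inputs and register the output slot.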
    if (reshapeLayer == nullptr)
    {
        if (Connect(layer, tfLiteNode, delegateData) != kTfLiteOk)
        {
            return kTfLiteError;
        }
    }

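    // If the TFLite output has more than two dimensions, reshape the 2D FullyConnected result back to
    // the original output shape.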
    if (outputTensorInfo.GetNumDimensions() > 2)
    {
        layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo,
                                delegateData, nodeIndex);
        if (!layer)
        {
            TF_LITE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ",
                operatorCode,
                nodeIndex);
            return kTfLiteError;
        }
    }

    if (!tfLiteNodeParameters)
    {
        // No Activation
        return kTfLiteOk;
    }

    // Check and Create Activation
    return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData, nodeIndex);
}
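
// Illustrative sketch (not part of this header, names outside this file are assumptions): the delegate's
// operator visitor typically dispatches to this function from its switch over TfLite builtin codes,
// roughly as follows:
//
//     case kTfLiteBuiltinFullyConnected:
//         return VisitFullyConnectedOperator(delegateData, tfLiteContext, tfLiteNode,
//                                            nodeIndex, kTfLiteBuiltinFullyConnected);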

} // namespace armnnDelegate