blob: 1129951104e7c75636969d1a971484590fdf3d7d [file] [log] [blame]
//
// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Sadik Armagan6e36a642020-11-10 21:18:41 +00008#include "DelegateUtils.hpp"
Mike Kelly04d82292023-01-19 18:29:40 +00009#include "armnnUtils/TensorUtils.hpp"
Finn Williams6f9f9902020-11-13 13:23:15 +000010#include <armnn/utility/IgnoreUnused.hpp>
Sadik Armagan6e36a642020-11-10 21:18:41 +000011
Sadik Armagan62483be2020-10-23 17:14:43 +010012#include <tensorflow/lite/builtin_ops.h>
13#include <tensorflow/lite/c/builtin_op_data.h>
14#include <tensorflow/lite/c/common.h>
15#include <tensorflow/lite/minimal_logging.h>
16
17namespace armnnDelegate
18{
19
20TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
21 TfLiteContext* tfLiteContext,
22 TfLiteNode* tfLiteNode,
23 int nodeIndex,
24 int32_t operatorCode)
25{
Sadik Armagan6e36a642020-11-10 21:18:41 +000026 auto numInputs = tfLiteNode->inputs->size;
27 if (numInputs < 2)
28 {
29 TF_LITE_MAYBE_KERNEL_LOG(
30 tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
31 2, numInputs, nodeIndex);
32 return kTfLiteError;
33 }
34 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Mike Kelly84d63782022-05-06 12:14:16 +010035 bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);
Sadik Armagan6e36a642020-11-10 21:18:41 +000036
37 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
38 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000039 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000040 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000041 return kTfLiteError;
42 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000043
Sadik Armagan6e36a642020-11-10 21:18:41 +000044 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000045 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000046 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000047 return kTfLiteError;
48 }
49
50 const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000051 if (!IsValid(tfLiteContext, tfLiteWeightsTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000052 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000053 return kTfLiteError;
54 }
55
Mike Kelly84d63782022-05-06 12:14:16 +010056 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Ryan OShea4c231de2023-01-17 15:19:20 +000057 const armnn::TensorInfo& weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010058 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan6e36a642020-11-10 21:18:41 +000059
Ryan OShea3ad2e142023-01-13 10:19:20 +000060 // Check that we support fused activation before we attempt to create a layer
61 auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams *>(tfLiteNode->builtin_data);
Ryan OShea475c7a82023-01-30 14:24:15 +000062 TfLiteFusedActivation activationType=kTfLiteActNone;
Ryan OShea3ad2e142023-01-13 10:19:20 +000063 if (tfLiteNodeParameters)
64 {
65 activationType = tfLiteNodeParameters->activation;
Ryan OShea3ad2e142023-01-13 10:19:20 +000066 TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
67 outputTensorInfo, activationType);
68 if(activationStatus != kTfLiteOk)
69 {
70 return kTfLiteError;
71 }
72 }
73
Sadik Armagan6e36a642020-11-10 21:18:41 +000074 // Fully Connected Layer accepts two dimensional weights input
75 int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
76 if (weightsDimension != 2)
77 {
78 TF_LITE_MAYBE_KERNEL_LOG(
79 tfLiteContext,
80 "TfLiteArmnnDelegate: Dimension #$d for Fully Connected weights is not supported by Armnn"
81 " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
82 return kTfLiteError;
83 }
84
85 armnn::TensorInfo biasTensorInfo;
86 if (biasEnabled)
87 {
88 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000089 if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000090 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000091 return kTfLiteError;
92 }
Sadik Armagan6e36a642020-11-10 21:18:41 +000093 biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
94 }
95 else
96 {
97 biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
98 }
99
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000100 armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000101 if (inputTensorInfo.GetNumDimensions() > 2)
102 {
103 // Calculate reshape to flatten to 2D [batch_size, input_size]
104 std::vector<unsigned int> reshapedDimensions(2);
105 reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
106 reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
107
108 if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
109 {
110 TF_LITE_MAYBE_KERNEL_LOG(
111 tfLiteContext,
112 "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
113 reshapedDimensions[1], operatorCode, nodeIndex);
114 return kTfLiteError;
115 }
116
117 reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
118 }
Mike Kelly04d82292023-01-19 18:29:40 +0000119 armnn::TensorInfo reshapedOutputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
120
121 if (outputTensorInfo.GetNumDimensions() > 2)
122 {
123 // Calculate reshape to flatten to 2D [batch_size, input_size]
124 std::vector<unsigned int> reshapedDimensions(2);
125 reshapedDimensions[1] = weightsTensorInfo.GetShape()[0];
126 reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1];
127
128 if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
129 {
130 TF_LITE_MAYBE_KERNEL_LOG(
131 tfLiteContext,
132 "TfLiteArmnnDelegate: Failed to deduce output tensor shape from filter size #%d #%d node #%d: ",
133 reshapedDimensions[1], operatorCode, nodeIndex);
134 return kTfLiteError;
135 }
136 reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
137 }
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000138
Sadik Armagan6e36a642020-11-10 21:18:41 +0000139 armnn::FullyConnectedDescriptor descriptor;
140 descriptor.m_TransposeWeightMatrix = true;
141 descriptor.m_BiasEnabled = biasEnabled;
Ryan OShea4c231de2023-01-17 15:19:20 +0000142 descriptor.m_ConstantWeights = weightsTensorInfo.IsConstant();
Sadik Armagan6e36a642020-11-10 21:18:41 +0000143
144 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100145 armnn::BackendId setBackend;
Sadik Armagan6e36a642020-11-10 21:18:41 +0000146 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
147 {
Mike Kelly04d82292023-01-19 18:29:40 +0000148
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000149 FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED",
Sadik Armagan6e36a642020-11-10 21:18:41 +0000150 tfLiteContext,
151 IsFullyConnectedSupported,
152 delegateData.m_Backends,
153 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100154 setBackend,
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000155 reshapedTensorInfo,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000156 outputTensorInfo,
157 weightsTensorInfo,
158 biasTensorInfo,
159 descriptor);
160 };
161
162 if (!delegateData.m_Network)
163 {
Mike Kelly04d82292023-01-19 18:29:40 +0000164 validateFunc(reshapedOutputTensorInfo, isSupported);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000165 return isSupported ? kTfLiteOk : kTfLiteError;
166 }
167
Matthew Sloyan81beae32021-07-13 19:46:11 +0100168 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100169 layer->SetBackendId(setBackend);
Matthew Sloyan81beae32021-07-13 19:46:11 +0100170 ARMNN_ASSERT(layer != nullptr);
171
172 // Add a constant layer for weights and biases if inputs are constant.
Ryan OShea4c231de2023-01-17 15:19:20 +0000173 if (weightsTensorInfo.IsConstant())
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000174 {
175 auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
Ryan OShea4c231de2023-01-17 15:19:20 +0000176 weightsTensorInfo);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000177
Matthew Sloyan81beae32021-07-13 19:46:11 +0100178 armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor);
179
180 weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
181 weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo);
182 }
183
184 if (biasEnabled)
185 {
186 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
Ryan OShea4c231de2023-01-17 15:19:20 +0000187 if(biasTensorInfo.IsConstant())
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000188 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000189 auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
Ryan OShea4c231de2023-01-17 15:19:20 +0000190 biasTensorInfo);
Matthew Sloyan81beae32021-07-13 19:46:11 +0100191
192 armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
193 ARMNN_ASSERT(biasLayer != nullptr);
194
195 biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
196 biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000197 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000198 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000199
Ryan OShea4c231de2023-01-17 15:19:20 +0000200 // The data input can also be constant, so we must check that this is also allocated to an input slot
201 if(inputTensorInfo.IsConstant())
202 {
203 auto input =
204 CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[0]],
205 inputTensorInfo);
206
207 armnn::IConnectableLayer *inputLayer = delegateData.m_Network->AddConstantLayer(input);
208 inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
209 inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
210 }
211
Sadik Armagan6e36a642020-11-10 21:18:41 +0000212 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
213 outputSlot.SetTensorInfo(outputTensorInfo);
214
215 armnn::IConnectableLayer* reshapeLayer = nullptr;
216 if (inputTensorInfo.GetNumDimensions() > 2)
217 {
218 // Add reshape to flatten to 2D [batch_size, input_size]
Sadik Armagan6e36a642020-11-10 21:18:41 +0000219 armnn::ReshapeDescriptor reshapeDescriptor;
220 reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
221 reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
222 ARMNN_ASSERT(reshapeLayer != nullptr);
223
224 reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
225
226 // Connect
227 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
228 reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
Matthew Sloyan81beae32021-07-13 19:46:11 +0100229
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000230 if (!descriptor.m_ConstantWeights)
231 {
232 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
Matthew Sloyan81beae32021-07-13 19:46:11 +0100233 }
234
Ryan OShea4c231de2023-01-17 15:19:20 +0000235 if (biasEnabled && !biasTensorInfo.IsConstant())
Matthew Sloyan81beae32021-07-13 19:46:11 +0100236 {
237 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000238 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000239 delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
240 }
241
242 if (reshapeLayer == nullptr)
243 {
Ryan OShea4c231de2023-01-17 15:19:20 +0000244 if(Connect(layer, tfLiteNode, delegateData) != kTfLiteOk)
245 {
246 return kTfLiteError;
247 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000248 }
Ryan OShea3ad2e142023-01-13 10:19:20 +0000249
Mike Kelly04d82292023-01-19 18:29:40 +0000250 if (outputTensorInfo.GetNumDimensions() > 2)
251 {
252 layer = AddReshapeLayer(tfLiteContext, tfLiteNode, layer, reshapedOutputTensorInfo, outputTensorInfo,
253 delegateData);
254 if (!layer)
255 {
256 TF_LITE_MAYBE_KERNEL_LOG(
257 tfLiteContext,
258 "TfLiteArmnnDelegate: Failed to add reshape for FullyConnected #%d node #%d: ",
259 operatorCode,
260 nodeIndex);
261 return kTfLiteError;
262 }
263 }
264
Sadik Armagan6e36a642020-11-10 21:18:41 +0000265 if (!tfLiteNodeParameters)
266 {
267 // No Activation
268 return kTfLiteOk;
269 }
Ryan OShea3ad2e142023-01-13 10:19:20 +0000270
271 // Check and Create Activation
Sadik Armagan6e36a642020-11-10 21:18:41 +0000272 return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100273}
274
Sadik Armagan6e36a642020-11-10 21:18:41 +0000275} // namespace armnnDelegate