blob: a2960e299bbb82e0af10a156f0cd6eb9fae0e4af [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include "DelegateUtils.hpp"
#include <armnn/utility/IgnoreUnused.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

#include <vector>
15
16namespace armnnDelegate
17{
18
19TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t operatorCode)
24{
Sadik Armagan6e36a642020-11-10 21:18:41 +000025 auto numInputs = tfLiteNode->inputs->size;
26 if (numInputs < 2)
27 {
28 TF_LITE_MAYBE_KERNEL_LOG(
29 tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
30 2, numInputs, nodeIndex);
31 return kTfLiteError;
32 }
33 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Mike Kelly84d63782022-05-06 12:14:16 +010034 bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);
Sadik Armagan6e36a642020-11-10 21:18:41 +000035
36 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
37 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000038 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000039 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000040 return kTfLiteError;
41 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000042
Sadik Armagan6e36a642020-11-10 21:18:41 +000043 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000044 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000045 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000046 return kTfLiteError;
47 }
48
49 const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000050 if (!IsValid(tfLiteContext, tfLiteWeightsTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000051 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000052 return kTfLiteError;
53 }
54
Mike Kelly84d63782022-05-06 12:14:16 +010055 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
56 armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010057 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan6e36a642020-11-10 21:18:41 +000058
Sadik Armagan6e36a642020-11-10 21:18:41 +000059 // Fully Connected Layer accepts two dimensional weights input
60 int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
61 if (weightsDimension != 2)
62 {
63 TF_LITE_MAYBE_KERNEL_LOG(
64 tfLiteContext,
65 "TfLiteArmnnDelegate: Dimension #$d for Fully Connected weights is not supported by Armnn"
66 " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
67 return kTfLiteError;
68 }
69
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000070 bool isConstantWeights = tflite::IsConstantTensor(&tfLiteWeightsTensor);
71
Sadik Armagan6e36a642020-11-10 21:18:41 +000072 armnn::TensorInfo biasTensorInfo;
73 if (biasEnabled)
74 {
75 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000076 if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000077 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000078 return kTfLiteError;
79 }
Sadik Armagan6e36a642020-11-10 21:18:41 +000080 biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
81 }
82 else
83 {
84 biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
85 }
86
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +000087 armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +000088 if (inputTensorInfo.GetNumDimensions() > 2)
89 {
90 // Calculate reshape to flatten to 2D [batch_size, input_size]
91 std::vector<unsigned int> reshapedDimensions(2);
92 reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
93 reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
94
95 if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
96 {
97 TF_LITE_MAYBE_KERNEL_LOG(
98 tfLiteContext,
99 "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
100 reshapedDimensions[1], operatorCode, nodeIndex);
101 return kTfLiteError;
102 }
103
104 reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
105 }
106
Sadik Armagan6e36a642020-11-10 21:18:41 +0000107 armnn::FullyConnectedDescriptor descriptor;
108 descriptor.m_TransposeWeightMatrix = true;
109 descriptor.m_BiasEnabled = biasEnabled;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000110 descriptor.m_ConstantWeights = isConstantWeights;
Sadik Armagan6e36a642020-11-10 21:18:41 +0000111
112 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100113 armnn::BackendId setBackend;
Sadik Armagan6e36a642020-11-10 21:18:41 +0000114 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
115 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000116 FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED",
Sadik Armagan6e36a642020-11-10 21:18:41 +0000117 tfLiteContext,
118 IsFullyConnectedSupported,
119 delegateData.m_Backends,
120 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100121 setBackend,
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000122 reshapedTensorInfo,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000123 outputTensorInfo,
124 weightsTensorInfo,
125 biasTensorInfo,
126 descriptor);
127 };
128
129 if (!delegateData.m_Network)
130 {
131 validateFunc(outputTensorInfo, isSupported);
132 return isSupported ? kTfLiteOk : kTfLiteError;
133 }
134
Matthew Sloyan81beae32021-07-13 19:46:11 +0100135 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100136 layer->SetBackendId(setBackend);
Matthew Sloyan81beae32021-07-13 19:46:11 +0100137 ARMNN_ASSERT(layer != nullptr);
138
139 // Add a constant layer for weights and biases if inputs are constant.
140 if (isConstantWeights)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000141 {
142 auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
143 weightsTensorInfo,
144 armnn::Optional<armnn::PermutationVector&>());
Sadik Armagan6e36a642020-11-10 21:18:41 +0000145
Matthew Sloyan81beae32021-07-13 19:46:11 +0100146 armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor);
147
148 weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
149 weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo);
150 }
151
152 if (biasEnabled)
153 {
154 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
155 if(tflite::IsConstantTensor(&tfLiteBiasTensor))
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000156 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000157 auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
158 biasTensorInfo,
159 armnn::Optional<armnn::PermutationVector&>());
Matthew Sloyan81beae32021-07-13 19:46:11 +0100160
161 armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
162 ARMNN_ASSERT(biasLayer != nullptr);
163
164 biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
165 biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000166 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000167 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000168
Sadik Armagan6e36a642020-11-10 21:18:41 +0000169 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
170 outputSlot.SetTensorInfo(outputTensorInfo);
171
172 armnn::IConnectableLayer* reshapeLayer = nullptr;
173 if (inputTensorInfo.GetNumDimensions() > 2)
174 {
175 // Add reshape to flatten to 2D [batch_size, input_size]
Sadik Armagan6e36a642020-11-10 21:18:41 +0000176 armnn::ReshapeDescriptor reshapeDescriptor;
177 reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
178 reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
179 ARMNN_ASSERT(reshapeLayer != nullptr);
180
181 reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
182
183 // Connect
184 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
185 reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
Matthew Sloyan81beae32021-07-13 19:46:11 +0100186
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000187 if (!descriptor.m_ConstantWeights)
188 {
189 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
Matthew Sloyan81beae32021-07-13 19:46:11 +0100190 }
191
192 if (biasEnabled && !tflite::IsConstantTensor(&tfLiteTensors[tfLiteNode->inputs->data[2]]))
193 {
194 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000195 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000196 delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
197 }
198
199 if (reshapeLayer == nullptr)
200 {
201 Connect(layer, tfLiteNode, delegateData);
202 }
203
Teresa Charlin1c717642020-11-25 18:34:51 +0000204 auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000205 if (!tfLiteNodeParameters)
206 {
207 // No Activation
208 return kTfLiteOk;
209 }
210
211 // Check Activation
212 TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
213 return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100214}
215
Sadik Armagan6e36a642020-11-10 21:18:41 +0000216} // namespace armnnDelegate