blob: e8c13f2053ac909f2031698023e103ac9b010937 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Sadik Armagan6e36a642020-11-10 21:18:41 +00008#include "DelegateUtils.hpp"
Finn Williams6f9f9902020-11-13 13:23:15 +00009#include <armnn/utility/IgnoreUnused.hpp>
Sadik Armagan6e36a642020-11-10 21:18:41 +000010
Sadik Armagan62483be2020-10-23 17:14:43 +010011#include <tensorflow/lite/builtin_ops.h>
12#include <tensorflow/lite/c/builtin_op_data.h>
13#include <tensorflow/lite/c/common.h>
14#include <tensorflow/lite/minimal_logging.h>
15
16namespace armnnDelegate
17{
18
19TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t operatorCode)
24{
Sadik Armagan6e36a642020-11-10 21:18:41 +000025 auto numInputs = tfLiteNode->inputs->size;
26 if (numInputs < 2)
27 {
28 TF_LITE_MAYBE_KERNEL_LOG(
29 tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
30 2, numInputs, nodeIndex);
31 return kTfLiteError;
32 }
33 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Mike Kelly84d63782022-05-06 12:14:16 +010034 bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);
Sadik Armagan6e36a642020-11-10 21:18:41 +000035
36 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
37 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000038 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000039 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000040 return kTfLiteError;
41 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000042
Sadik Armagan6e36a642020-11-10 21:18:41 +000043 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000044 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000045 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000046 return kTfLiteError;
47 }
48
49 const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000050 if (!IsValid(tfLiteContext, tfLiteWeightsTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000051 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000052 return kTfLiteError;
53 }
54
Mike Kelly84d63782022-05-06 12:14:16 +010055 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
56 armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
57 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
Sadik Armagan6e36a642020-11-10 21:18:41 +000058
Sadik Armagan6e36a642020-11-10 21:18:41 +000059 // Fully Connected Layer accepts two dimensional weights input
60 int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
61 if (weightsDimension != 2)
62 {
63 TF_LITE_MAYBE_KERNEL_LOG(
64 tfLiteContext,
65 "TfLiteArmnnDelegate: Dimension #$d for Fully Connected weights is not supported by Armnn"
66 " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
67 return kTfLiteError;
68 }
69
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000070 bool isConstantWeights = tflite::IsConstantTensor(&tfLiteWeightsTensor);
71
Sadik Armagan6e36a642020-11-10 21:18:41 +000072 armnn::TensorInfo biasTensorInfo;
73 if (biasEnabled)
74 {
75 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000076 if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000077 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000078 return kTfLiteError;
79 }
Sadik Armagan6e36a642020-11-10 21:18:41 +000080 biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
81 }
82 else
83 {
84 biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
85 }
86
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +000087 armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +000088 if (inputTensorInfo.GetNumDimensions() > 2)
89 {
90 // Calculate reshape to flatten to 2D [batch_size, input_size]
91 std::vector<unsigned int> reshapedDimensions(2);
92 reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
93 reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
94
95 if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
96 {
97 TF_LITE_MAYBE_KERNEL_LOG(
98 tfLiteContext,
99 "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
100 reshapedDimensions[1], operatorCode, nodeIndex);
101 return kTfLiteError;
102 }
103
104 reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
105 }
106
Sadik Armagan6e36a642020-11-10 21:18:41 +0000107 armnn::FullyConnectedDescriptor descriptor;
108 descriptor.m_TransposeWeightMatrix = true;
109 descriptor.m_BiasEnabled = biasEnabled;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000110 descriptor.m_ConstantWeights = isConstantWeights;
Sadik Armagan6e36a642020-11-10 21:18:41 +0000111
112 bool isSupported = false;
113 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
114 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000115 FORWARD_LAYER_SUPPORT_FUNC("FULLY_CONNECTED",
Sadik Armagan6e36a642020-11-10 21:18:41 +0000116 tfLiteContext,
117 IsFullyConnectedSupported,
118 delegateData.m_Backends,
119 isSupported,
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000120 reshapedTensorInfo,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000121 outputTensorInfo,
122 weightsTensorInfo,
123 biasTensorInfo,
124 descriptor);
125 };
126
127 if (!delegateData.m_Network)
128 {
129 validateFunc(outputTensorInfo, isSupported);
130 return isSupported ? kTfLiteOk : kTfLiteError;
131 }
132
Matthew Sloyan81beae32021-07-13 19:46:11 +0100133 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor);
134 ARMNN_ASSERT(layer != nullptr);
135
136 // Add a constant layer for weights and biases if inputs are constant.
137 if (isConstantWeights)
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000138 {
139 auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
140 weightsTensorInfo,
141 armnn::Optional<armnn::PermutationVector&>());
Sadik Armagan6e36a642020-11-10 21:18:41 +0000142
Matthew Sloyan81beae32021-07-13 19:46:11 +0100143 armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(weightsTensor);
144
145 weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
146 weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensorInfo);
147 }
148
149 if (biasEnabled)
150 {
151 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
152 if(tflite::IsConstantTensor(&tfLiteBiasTensor))
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000153 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000154 auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
155 biasTensorInfo,
156 armnn::Optional<armnn::PermutationVector&>());
Matthew Sloyan81beae32021-07-13 19:46:11 +0100157
158 armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
159 ARMNN_ASSERT(biasLayer != nullptr);
160
161 biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
162 biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000163 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000164 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000165
Sadik Armagan6e36a642020-11-10 21:18:41 +0000166 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
167 outputSlot.SetTensorInfo(outputTensorInfo);
168
169 armnn::IConnectableLayer* reshapeLayer = nullptr;
170 if (inputTensorInfo.GetNumDimensions() > 2)
171 {
172 // Add reshape to flatten to 2D [batch_size, input_size]
Sadik Armagan6e36a642020-11-10 21:18:41 +0000173 armnn::ReshapeDescriptor reshapeDescriptor;
174 reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
175 reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
176 ARMNN_ASSERT(reshapeLayer != nullptr);
177
178 reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
179
180 // Connect
181 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
182 reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
Matthew Sloyan81beae32021-07-13 19:46:11 +0100183
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000184 if (!descriptor.m_ConstantWeights)
185 {
186 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
Matthew Sloyan81beae32021-07-13 19:46:11 +0100187 }
188
189 if (biasEnabled && !tflite::IsConstantTensor(&tfLiteTensors[tfLiteNode->inputs->data[2]]))
190 {
191 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000192 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000193 delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
194 }
195
196 if (reshapeLayer == nullptr)
197 {
198 Connect(layer, tfLiteNode, delegateData);
199 }
200
Teresa Charlin1c717642020-11-25 18:34:51 +0000201 auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000202 if (!tfLiteNodeParameters)
203 {
204 // No Activation
205 return kTfLiteOk;
206 }
207
208 // Check Activation
209 TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
210 return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100211}
212
Sadik Armagan6e36a642020-11-10 21:18:41 +0000213} // namespace armnnDelegate