//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "DelegateUtils.hpp"
#include <armnn/utility/IgnoreUnused.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnDelegate
{

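// Checks whether a TfLite FULLY_CONNECTED node can be handled by the Arm NN delegate and,
// when a network is being constructed, adds the corresponding Arm NN FullyConnected layer
// (plus a Reshape layer in front of it for inputs with more than two dimensions).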
TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
                                         TfLiteContext* tfLiteContext,
                                         TfLiteNode* tfLiteNode,
                                         int nodeIndex,
                                         int32_t operatorCode)
{
    auto numInputs = tfLiteNode->inputs->size;
    if (numInputs < 2)
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
            2, numInputs, nodeIndex);
        return kTfLiteError;
    }
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
    bool biasEnabled = (numInputs == 3);

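    // The node carries the input at index 0, the weights at index 1 and the optional bias at index 2.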
    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
    if (!IsValid(tfLiteContext, tfLiteWeightsTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);

    // The Fully Connected layer only accepts two-dimensional weights.
    int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
    if (weightsDimension != 2)
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dimension #%d for Fully Connected weights is not supported by Armnn"
            " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
        return kTfLiteError;
    }

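    // Weights may be baked into the model as a constant tensor or produced at runtime by another
    // node; non-constant weights (and bias) are wired up as layer inputs further down.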
    bool isConstantWeights = tflite::IsConstantTensor(&tfLiteWeightsTensor);

    armnn::TensorInfo biasTensorInfo;
    if (biasEnabled)
    {
        const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
        if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
        {
            return kTfLiteError;
        }
        if ((isConstantWeights && !tflite::IsConstantTensor(&tfLiteBiasTensor))
            || (!isConstantWeights && tflite::IsConstantTensor(&tfLiteBiasTensor)))
        {
            TF_LITE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnDelegate: Weights and bias are not compatible"
                " in operator #%d node #%d: ", operatorCode, nodeIndex);
            return kTfLiteError;
        }
        biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
    }
    else
    {
        biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
    }

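    // Arm NN's FullyConnected layer expects a 2D input, so inputs with a higher rank are flattened
    // to [batch_size, input_size], deducing input_size from the weights' second dimension.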
    armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    if (inputTensorInfo.GetNumDimensions() > 2)
    {
        // Calculate the reshape needed to flatten the input to 2D [batch_size, input_size].
        std::vector<unsigned int> reshapedDimensions(2);
        reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
        reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];

        if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
        {
            TF_LITE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d in operator #%d node #%d: ",
                reshapedDimensions[1], operatorCode, nodeIndex);
            return kTfLiteError;
        }

        reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
    }

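    // TfLite stores fully connected weights as [num_units, input_size], so the weight matrix is
    // flagged as transposed relative to Arm NN's default layout.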
    armnn::FullyConnectedDescriptor descriptor;
    descriptor.m_TransposeWeightMatrix = true;
    descriptor.m_BiasEnabled = biasEnabled;
    descriptor.m_ConstantWeights = isConstantWeights;

    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   tfLiteContext,
                                   IsFullyConnectedSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   reshapedTensorInfo,
                                   outputTensorInfo,
                                   weightsTensorInfo,
                                   biasTensorInfo,
                                   descriptor);
    };

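    // With no network to build, this call is only validating support: ask the backends whether
    // they can execute the layer and report the result to the TfLite runtime.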
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

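    // Constant weights (and bias) are wrapped as ConstTensors and handed to the layer directly;
    // non-constant ones are left empty here and connected as inputs once the layer exists.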
    armnn::Optional<armnn::ConstTensor> optionalWeights = armnn::EmptyOptional();
    armnn::Optional<armnn::ConstTensor> optionalBiases = armnn::EmptyOptional();
    if (descriptor.m_ConstantWeights)
    {
        auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
                                               weightsTensorInfo,
                                               armnn::Optional<armnn::PermutationVector&>());
        optionalWeights = armnn::Optional<armnn::ConstTensor>(weightsTensor);

        if (biasEnabled)
        {
            const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
            auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
                                                biasTensorInfo,
                                                armnn::Optional<armnn::PermutationVector&>());
            optionalBiases = armnn::Optional<armnn::ConstTensor>(biasTensor);
        }
    }

    armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor,
                                                                                     optionalWeights,
                                                                                     optionalBiases);
    ARMNN_ASSERT(layer != nullptr);

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

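    // For inputs with more than two dimensions a Reshape layer is inserted in front of the
    // FullyConnected layer, so the input/weights/bias/output connections are made by hand here
    // instead of through the generic Connect() helper.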
    armnn::IConnectableLayer* reshapeLayer = nullptr;
    if (inputTensorInfo.GetNumDimensions() > 2)
    {
        // Add a reshape to flatten the input to 2D [batch_size, input_size].
        armnn::ReshapeDescriptor reshapeDescriptor;
        reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
        reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
        ARMNN_ASSERT(reshapeLayer != nullptr);

        reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);

        // Connect the TfLite input through the reshape into the FullyConnected layer.
        delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
        reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
        if (!descriptor.m_ConstantWeights)
        {
            delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
            if (biasEnabled)
            {
                delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
            }
        }
        delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
    }

    if (reshapeLayer == nullptr)
    {
        Connect(layer, tfLiteNode, delegateData);
    }

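    // Finally, handle any fused activation declared on the TfLite node.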
    auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
    if (!tfLiteNodeParameters)
    {
        // No Activation
        return kTfLiteOk;
    }

    // Check Activation
    TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
    return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
}

} // namespace armnnDelegate