blob: e94304fb2148abd9cf96e91eaf2c5a9becb5a924 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Sadik Armagan6e36a642020-11-10 21:18:41 +00008#include "DelegateUtils.hpp"
Finn Williams6f9f9902020-11-13 13:23:15 +00009#include <armnn/utility/IgnoreUnused.hpp>
Sadik Armagan6e36a642020-11-10 21:18:41 +000010
Sadik Armagan62483be2020-10-23 17:14:43 +010011#include <tensorflow/lite/builtin_ops.h>
12#include <tensorflow/lite/c/builtin_op_data.h>
13#include <tensorflow/lite/c/common.h>
14#include <tensorflow/lite/minimal_logging.h>
15
16namespace armnnDelegate
17{
18
19TfLiteStatus VisitFullyConnectedOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t operatorCode)
24{
Sadik Armagan6e36a642020-11-10 21:18:41 +000025 auto numInputs = tfLiteNode->inputs->size;
26 if (numInputs < 2)
27 {
28 TF_LITE_MAYBE_KERNEL_LOG(
29 tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
30 2, numInputs, nodeIndex);
31 return kTfLiteError;
32 }
33 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
34 bool biasEnabled = (numInputs == 3);
35
36 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
37 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000038 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000039 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000040 return kTfLiteError;
41 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000042
Sadik Armagan6e36a642020-11-10 21:18:41 +000043 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000044 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000045 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000046 return kTfLiteError;
47 }
48
49 const TfLiteTensor& tfLiteWeightsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000050 if (!IsValid(tfLiteContext, tfLiteWeightsTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000051 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000052 return kTfLiteError;
53 }
54
55 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000056 armnn::TensorInfo weightsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteWeightsTensor);
Sadik Armagan6e36a642020-11-10 21:18:41 +000057 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
58
Sadik Armagan6e36a642020-11-10 21:18:41 +000059 // Fully Connected Layer accepts two dimensional weights input
60 int32_t weightsDimension = static_cast<int32_t>(weightsTensorInfo.GetNumDimensions());
61 if (weightsDimension != 2)
62 {
63 TF_LITE_MAYBE_KERNEL_LOG(
64 tfLiteContext,
65 "TfLiteArmnnDelegate: Dimension #$d for Fully Connected weights is not supported by Armnn"
66 " in operator #%d node #%d: ", weightsDimension, operatorCode, nodeIndex);
67 return kTfLiteError;
68 }
69
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000070 bool isConstantWeights = tflite::IsConstantTensor(&tfLiteWeightsTensor);
71
Sadik Armagan6e36a642020-11-10 21:18:41 +000072 armnn::TensorInfo biasTensorInfo;
73 if (biasEnabled)
74 {
75 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000076 if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
Sadik Armagan6e36a642020-11-10 21:18:41 +000077 {
Sadik Armagan6e36a642020-11-10 21:18:41 +000078 return kTfLiteError;
79 }
Sadik Armagan6e36a642020-11-10 21:18:41 +000080 biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
81 }
82 else
83 {
84 biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
85 }
86
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +000087 armnn::TensorInfo reshapedTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +000088 if (inputTensorInfo.GetNumDimensions() > 2)
89 {
90 // Calculate reshape to flatten to 2D [batch_size, input_size]
91 std::vector<unsigned int> reshapedDimensions(2);
92 reshapedDimensions[1] = weightsTensorInfo.GetShape()[1];
93 reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
94
95 if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
96 {
97 TF_LITE_MAYBE_KERNEL_LOG(
98 tfLiteContext,
99 "TfLiteArmnnDelegate: Failed to deduce input tensor shape from filter size #%d #%d node #%d: ",
100 reshapedDimensions[1], operatorCode, nodeIndex);
101 return kTfLiteError;
102 }
103
104 reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
105 }
106
Sadik Armagan6e36a642020-11-10 21:18:41 +0000107 armnn::FullyConnectedDescriptor descriptor;
108 descriptor.m_TransposeWeightMatrix = true;
109 descriptor.m_BiasEnabled = biasEnabled;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000110 descriptor.m_ConstantWeights = isConstantWeights;
Sadik Armagan6e36a642020-11-10 21:18:41 +0000111
112 bool isSupported = false;
113 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
114 {
115 FORWARD_LAYER_SUPPORT_FUNC(__func__,
116 tfLiteContext,
117 IsFullyConnectedSupported,
118 delegateData.m_Backends,
119 isSupported,
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +0000120 reshapedTensorInfo,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000121 outputTensorInfo,
122 weightsTensorInfo,
123 biasTensorInfo,
124 descriptor);
125 };
126
127 if (!delegateData.m_Network)
128 {
129 validateFunc(outputTensorInfo, isSupported);
130 return isSupported ? kTfLiteOk : kTfLiteError;
131 }
132
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000133 armnn::Optional<armnn::ConstTensor> optionalWeights = armnn::EmptyOptional();
134 armnn::Optional<armnn::ConstTensor> optionalBiases = armnn::EmptyOptional();
135 if(descriptor.m_ConstantWeights)
136 {
137 auto weightsTensor = CreateConstTensor(&tfLiteWeightsTensor,
138 weightsTensorInfo,
139 armnn::Optional<armnn::PermutationVector&>());
140 optionalWeights = armnn::Optional<armnn::ConstTensor>(weightsTensor);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000141
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000142 if (biasEnabled)
143 {
144 const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
145 auto biasTensor = CreateConstTensor(&tfLiteBiasTensor,
146 biasTensorInfo,
147 armnn::Optional<armnn::PermutationVector&>());
148 optionalBiases = armnn::Optional<armnn::ConstTensor>(biasTensor);
149 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000150 }
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000151
152 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFullyConnectedLayer(descriptor,
153 optionalWeights,
154 optionalBiases);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000155 ARMNN_ASSERT(layer != nullptr);
156
157 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
158 outputSlot.SetTensorInfo(outputTensorInfo);
159
160 armnn::IConnectableLayer* reshapeLayer = nullptr;
161 if (inputTensorInfo.GetNumDimensions() > 2)
162 {
163 // Add reshape to flatten to 2D [batch_size, input_size]
Sadik Armagan6e36a642020-11-10 21:18:41 +0000164 armnn::ReshapeDescriptor reshapeDescriptor;
165 reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
166 reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor);
167 ARMNN_ASSERT(reshapeLayer != nullptr);
168
169 reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
170
171 // Connect
172 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(reshapeLayer->GetInputSlot(0));
173 reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000174 if (!descriptor.m_ConstantWeights)
175 {
176 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(1));
177 if (biasEnabled)
178 {
179 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(2));
180 }
181 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000182 delegateData.m_OutputSlotForNode[tfLiteNode->outputs->data[0]] = &outputSlot;
183 }
184
185 if (reshapeLayer == nullptr)
186 {
187 Connect(layer, tfLiteNode, delegateData);
188 }
189
Teresa Charlin1c717642020-11-25 18:34:51 +0000190 auto* tfLiteNodeParameters = reinterpret_cast<TfLiteFullyConnectedParams*>(tfLiteNode->builtin_data);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000191 if (!tfLiteNodeParameters)
192 {
193 // No Activation
194 return kTfLiteOk;
195 }
196
197 // Check Activation
198 TfLiteFusedActivation activationType = tfLiteNodeParameters->activation;
199 return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100200}
201
Sadik Armagan6e36a642020-11-10 21:18:41 +0000202} // namespace armnnDelegate