blob: 8f9a4e4ba0a48f933b875e5d026c0ced6e896d37 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
Sadik Armagan90a119b2022-08-05 16:12:49 +01002// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan62483be2020-10-23 17:14:43 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Finn Williams6f9f9902020-11-13 13:23:15 +00008#include <armnn/utility/IgnoreUnused.hpp>
9
David Monahan1670b0c2020-11-18 14:40:27 +000010#include "DelegateUtils.hpp"
11
Sadik Armagan62483be2020-10-23 17:14:43 +010012#include <tensorflow/lite/builtin_ops.h>
13#include <tensorflow/lite/c/builtin_op_data.h>
14#include <tensorflow/lite/c/common.h>
15#include <tensorflow/lite/minimal_logging.h>
David Monahan1670b0c2020-11-18 14:40:27 +000016#include <numeric>
Sadik Armagan62483be2020-10-23 17:14:43 +010017
18namespace armnnDelegate
19{
20
Sadik Armagan937565b2021-04-21 14:03:28 +010021TfLiteStatus VisitCastOperator(DelegateData& delegateData,
22 TfLiteContext* tfLiteContext,
23 TfLiteNode* tfLiteNode,
24 int nodeIndex,
25 int32_t operatorCode)
26{
27 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
29
30 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
31 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
32 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
33 {
34 return kTfLiteError;
35 }
36
37 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
38 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
39 {
40 return kTfLiteError;
41 }
42
43 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010044 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan937565b2021-04-21 14:03:28 +010045
46 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010047 armnn::BackendId setBackend;
Sadik Armagan937565b2021-04-21 14:03:28 +010048 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
49 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000050 FORWARD_LAYER_SUPPORT_FUNC("CAST",
Sadik Armagan937565b2021-04-21 14:03:28 +010051 tfLiteContext,
52 IsCastSupported,
53 delegateData.m_Backends,
54 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010055 setBackend,
Sadik Armagan937565b2021-04-21 14:03:28 +010056 inputTensorInfo,
57 outInfo);
58 };
59
60 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
61 // support for the operator
62 // If supported, VisitCastOperator will be called again to add the layer to the network as seen further below
63 if (!delegateData.m_Network)
64 {
65 validateFunc(outputTensorInfo, isSupported);
66 return isSupported ? kTfLiteOk : kTfLiteError;
67 }
68
69 // Add a Cast layer
70 armnn::IConnectableLayer* layer = delegateData.m_Network->AddCastLayer();
Cathal Corbett53837672022-09-01 11:34:37 +010071 layer->SetBackendId(setBackend);
Sadik Armagan937565b2021-04-21 14:03:28 +010072 ARMNN_ASSERT(layer != nullptr);
73
74 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
75 outputSlot.SetTensorInfo(outputTensorInfo);
76
77 // Connect
78 return Connect(layer, tfLiteNode, delegateData);
79}
80
81
David Monahan1670b0c2020-11-18 14:40:27 +000082TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo,
Matthew Sloyanf00f6c22020-12-07 13:33:24 +000083 const std::vector<int32_t>& targetShape,
84 armnn::ReshapeDescriptor& reshapeDesc)
David Monahan1670b0c2020-11-18 14:40:27 +000085{
86 std::vector<unsigned int> outputDims(targetShape.begin(), targetShape.end());
87 const auto stretchDim = std::find(targetShape.begin(), targetShape.end(), -1);
88
89 if (stretchDim != targetShape.end())
90 {
91 if (std::find(std::next(stretchDim), targetShape.end(), -1) != targetShape.end())
92 {
93 // Return kTfLiteError and log the error after returning
94 return kTfLiteError;
95 }
96
97 auto targetNumElements =
98 armnn::numeric_cast<unsigned int>(
99 std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies<int32_t>()));
100
101 auto stretchIndex = static_cast<size_t>(std::distance(targetShape.begin(), stretchDim));
102 outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
103 }
104
105 armnn::TensorShape outputShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()),
106 outputDims.data());
107 reshapeDesc.m_TargetShape = outputShape;
108 return kTfLiteOk;
109}
110
Sadik Armagan62483be2020-10-23 17:14:43 +0100111TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
112 TfLiteContext* tfLiteContext,
113 TfLiteNode* tfLiteNode,
114 int nodeIndex,
115 int32_t operatorCode)
116{
David Monahan1670b0c2020-11-18 14:40:27 +0000117 auto numInputs = tfLiteNode->inputs->size;
Finn Williams6f9f9902020-11-13 13:23:15 +0000118
David Monahan1670b0c2020-11-18 14:40:27 +0000119 if (numInputs == 2)
120 {
121 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
122 }
123 else
124 {
125 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
126 }
127 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
128
129 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
130 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000131 if (!IsValid(tfLiteContext, tfLiteInputTensor0, operatorCode, nodeIndex))
David Monahan1670b0c2020-11-18 14:40:27 +0000132 {
David Monahan1670b0c2020-11-18 14:40:27 +0000133 return kTfLiteError;
134 }
135
136 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000137 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
David Monahan1670b0c2020-11-18 14:40:27 +0000138 {
David Monahan1670b0c2020-11-18 14:40:27 +0000139 return kTfLiteError;
140 }
141
142 const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
Sadik Armagan90a119b2022-08-05 16:12:49 +0100143 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
David Monahan1670b0c2020-11-18 14:40:27 +0000144
145 armnn::ReshapeDescriptor reshapeDesc;
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000146 std::vector<int32_t> targetShape;
Finn Williamsf806c4d2021-02-22 15:13:12 +0000147
148 TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);
David Monahan1670b0c2020-11-18 14:40:27 +0000149
150 // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both.
Finn Williamsf806c4d2021-02-22 15:13:12 +0000151 // Options might be set without valid data. we need to check the dimensions are in a valid range.
152 if (reshapeOptions && reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
153 {
154 for (int i=0; i < reshapeOptions->num_dimensions; ++i)
155 {
156 targetShape.push_back(reshapeOptions->shape[i]);
157 }
158 }
159 else if (numInputs == 2)
David Monahan1670b0c2020-11-18 14:40:27 +0000160 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000161 // Get shape from the second input tensor
162 const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000163 if (!IsValid(tfLiteContext, tfLiteShapeInputTensor, operatorCode, nodeIndex))
David Monahane03d9c22020-11-20 09:58:54 +0000164 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000165 return kTfLiteError;
166 }
167
168 if (tfLiteShapeInputTensor.dims->size != 1)
169 {
170 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
171 "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000172 "operator #%d node #%d: Falling back to TfLiteOptions.",
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000173 operatorCode, nodeIndex);
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000174 }
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000175 else
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000176 {
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000177 // Get the shape data out of the input tensor
178 auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
179 auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
180 for (auto i=0; i < shapeTensorNumValues; ++i)
181 {
182 targetShape.push_back(*(shapeTensorDataPtr+i));
183 }
David Monahane03d9c22020-11-20 09:58:54 +0000184 }
185 }
Finn Williamsf806c4d2021-02-22 15:13:12 +0000186 else
David Monahane03d9c22020-11-20 09:58:54 +0000187 {
Finn Williamsf806c4d2021-02-22 15:13:12 +0000188 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
189 "Target shape not defined in reshape parameters or input tensor. "
190 "At least one method required in operator #%d node #%d: ",
191 operatorCode, nodeIndex);
192 return kTfLiteError;
David Monahan1670b0c2020-11-18 14:40:27 +0000193 }
David Monahane03d9c22020-11-20 09:58:54 +0000194
195 // Use the data to create the required tensor shape.
196 if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
David Monahan1670b0c2020-11-18 14:40:27 +0000197 {
198 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
David Monahane03d9c22020-11-20 09:58:54 +0000199 "TfLiteArmnnDelegate: At most one component of shape can be -1 in: "
200 "operator #%d node #%d: ",
David Monahan1670b0c2020-11-18 14:40:27 +0000201 operatorCode, nodeIndex);
David Monahane03d9c22020-11-20 09:58:54 +0000202 return kTfLiteError;
203 }
204
205 if (reshapeDesc.m_TargetShape.GetNumElements() != inputTensorInfo0.GetNumElements())
206 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000207 TF_LITE_MAYBE_KERNEL_LOG(
208 tfLiteContext,
209 "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
210 "operator #%d node #%d: ",
211 operatorCode, nodeIndex);
David Monahane03d9c22020-11-20 09:58:54 +0000212 return kTfLiteError;
David Monahan1670b0c2020-11-18 14:40:27 +0000213 }
214
215 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100216 armnn::BackendId setBackend;
David Monahan1670b0c2020-11-18 14:40:27 +0000217 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
218 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000219 FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
David Monahan1670b0c2020-11-18 14:40:27 +0000220 tfLiteContext,
221 IsReshapeSupported,
222 delegateData.m_Backends,
223 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100224 setBackend,
David Monahan1670b0c2020-11-18 14:40:27 +0000225 inputTensorInfo0,
226 outInfo,
227 reshapeDesc);
228 };
229
230 if (!delegateData.m_Network)
231 {
232 validateFunc(outputTensorInfo, isSupported);
233 return isSupported ? kTfLiteOk : kTfLiteError;
234 }
235
236 armnn::IConnectableLayer* layer = delegateData.m_Network->AddReshapeLayer(reshapeDesc);
Cathal Corbett53837672022-09-01 11:34:37 +0100237 layer->SetBackendId(setBackend);
David Monahan1670b0c2020-11-18 14:40:27 +0000238 ARMNN_ASSERT(layer != nullptr);
239
240 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
241 outputSlot.SetTensorInfo(outputTensorInfo);
242
243 // Connect
244 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100245}
246
247TfLiteStatus VisitSqueezeOperator(DelegateData& delegateData,
248 TfLiteContext* tfLiteContext,
249 TfLiteNode* tfLiteNode,
250 int nodeIndex,
251 int32_t operatorCode)
252{
Finn Williams6f9f9902020-11-13 13:23:15 +0000253 armnn::IgnoreUnused(delegateData,
254 tfLiteContext,
255 tfLiteNode,
256 nodeIndex,
257 operatorCode);
258
Sadik Armagan62483be2020-10-23 17:14:43 +0100259 return kTfLiteError;
260}
261
262TfLiteStatus VisitExpandDimsOperator(DelegateData& delegateData,
263 TfLiteContext* tfLiteContext,
264 TfLiteNode* tfLiteNode,
265 int nodeIndex,
266 int32_t operatorCode)
267{
Finn Williams6f9f9902020-11-13 13:23:15 +0000268 armnn::IgnoreUnused(delegateData,
269 tfLiteContext,
270 tfLiteNode,
271 nodeIndex,
272 operatorCode);
273
Sadik Armagan62483be2020-10-23 17:14:43 +0100274 return kTfLiteError;
275}
276
277} // namespace armnnDelegate