blob: 6b10e448e7e1af64391d1d2da7b212b17e7aa986 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan62483be2020-10-23 17:14:43 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
David Monahan1670b0c2020-11-18 14:40:27 +00009
Sadik Armagan62483be2020-10-23 17:14:43 +010010#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
Sadik Armagan937565b2021-04-21 14:03:28 +010018TfLiteStatus VisitCastOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
22 int32_t operatorCode)
23{
24 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26
27 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
28 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
30 {
31 return kTfLiteError;
32 }
33
34 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
35 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
36 {
37 return kTfLiteError;
38 }
39
40 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010041 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Sadik Armagan937565b2021-04-21 14:03:28 +010042
43 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010044 armnn::BackendId setBackend;
Sadik Armagan937565b2021-04-21 14:03:28 +010045 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
46 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000047 FORWARD_LAYER_SUPPORT_FUNC("CAST",
Sadik Armagan937565b2021-04-21 14:03:28 +010048 tfLiteContext,
49 IsCastSupported,
50 delegateData.m_Backends,
51 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010052 setBackend,
Sadik Armagan937565b2021-04-21 14:03:28 +010053 inputTensorInfo,
54 outInfo);
55 };
56
57 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
58 // support for the operator
59 // If supported, VisitCastOperator will be called again to add the layer to the network as seen further below
60 if (!delegateData.m_Network)
61 {
62 validateFunc(outputTensorInfo, isSupported);
63 return isSupported ? kTfLiteOk : kTfLiteError;
64 }
65
66 // Add a Cast layer
Mike Kelly07169c82023-08-02 13:23:09 +010067 auto layerName = GetLayerName(armnn::LayerType::Cast, nodeIndex);
68 armnn::IConnectableLayer* layer = delegateData.m_Network->AddCastLayer(layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +010069 layer->SetBackendId(setBackend);
Sadik Armagan937565b2021-04-21 14:03:28 +010070 ARMNN_ASSERT(layer != nullptr);
71
72 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
73 outputSlot.SetTensorInfo(outputTensorInfo);
74
Ryan OShea4c231de2023-01-17 15:19:20 +000075 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +010076 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +000077 {
78 return kTfLiteError;
79 }
80
Sadik Armagan937565b2021-04-21 14:03:28 +010081 // Connect
82 return Connect(layer, tfLiteNode, delegateData);
83}
84
Sadik Armagan62483be2020-10-23 17:14:43 +010085TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
86 TfLiteContext* tfLiteContext,
87 TfLiteNode* tfLiteNode,
88 int nodeIndex,
89 int32_t operatorCode)
90{
David Monahan1670b0c2020-11-18 14:40:27 +000091 auto numInputs = tfLiteNode->inputs->size;
Finn Williams6f9f9902020-11-13 13:23:15 +000092
David Monahan1670b0c2020-11-18 14:40:27 +000093 if (numInputs == 2)
94 {
95 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
96 }
97 else
98 {
99 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
100 }
101 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
102
103 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
104 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000105 if (!IsValid(tfLiteContext, tfLiteInputTensor0, operatorCode, nodeIndex))
David Monahan1670b0c2020-11-18 14:40:27 +0000106 {
David Monahan1670b0c2020-11-18 14:40:27 +0000107 return kTfLiteError;
108 }
109
110 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000111 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
David Monahan1670b0c2020-11-18 14:40:27 +0000112 {
David Monahan1670b0c2020-11-18 14:40:27 +0000113 return kTfLiteError;
114 }
115
116 const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
Sadik Armagan90a119b2022-08-05 16:12:49 +0100117 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
David Monahan1670b0c2020-11-18 14:40:27 +0000118
119 armnn::ReshapeDescriptor reshapeDesc;
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000120 std::vector<int32_t> targetShape;
Finn Williamsf806c4d2021-02-22 15:13:12 +0000121
122 TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);
David Monahan1670b0c2020-11-18 14:40:27 +0000123
124 // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both.
Finn Williamsf806c4d2021-02-22 15:13:12 +0000125 // Options might be set without valid data. we need to check the dimensions are in a valid range.
126 if (reshapeOptions && reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
127 {
128 for (int i=0; i < reshapeOptions->num_dimensions; ++i)
129 {
130 targetShape.push_back(reshapeOptions->shape[i]);
131 }
132 }
133 else if (numInputs == 2)
David Monahan1670b0c2020-11-18 14:40:27 +0000134 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000135 // Get shape from the second input tensor
136 const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000137 if (!IsValid(tfLiteContext, tfLiteShapeInputTensor, operatorCode, nodeIndex))
David Monahane03d9c22020-11-20 09:58:54 +0000138 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000139 return kTfLiteError;
140 }
141
142 if (tfLiteShapeInputTensor.dims->size != 1)
143 {
144 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
145 "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000146 "operator #%d node #%d: Falling back to TfLiteOptions.",
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000147 operatorCode, nodeIndex);
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000148 }
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000149 else
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000150 {
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000151 // Get the shape data out of the input tensor
152 auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
153 auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
154 for (auto i=0; i < shapeTensorNumValues; ++i)
155 {
156 targetShape.push_back(*(shapeTensorDataPtr+i));
157 }
David Monahane03d9c22020-11-20 09:58:54 +0000158 }
159 }
Finn Williamsf806c4d2021-02-22 15:13:12 +0000160 else
David Monahane03d9c22020-11-20 09:58:54 +0000161 {
Finn Williamsf806c4d2021-02-22 15:13:12 +0000162 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
163 "Target shape not defined in reshape parameters or input tensor. "
164 "At least one method required in operator #%d node #%d: ",
165 operatorCode, nodeIndex);
166 return kTfLiteError;
David Monahan1670b0c2020-11-18 14:40:27 +0000167 }
David Monahane03d9c22020-11-20 09:58:54 +0000168
169 // Use the data to create the required tensor shape.
170 if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
David Monahan1670b0c2020-11-18 14:40:27 +0000171 {
172 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
David Monahane03d9c22020-11-20 09:58:54 +0000173 "TfLiteArmnnDelegate: At most one component of shape can be -1 in: "
174 "operator #%d node #%d: ",
David Monahan1670b0c2020-11-18 14:40:27 +0000175 operatorCode, nodeIndex);
David Monahane03d9c22020-11-20 09:58:54 +0000176 return kTfLiteError;
177 }
178
179 if (reshapeDesc.m_TargetShape.GetNumElements() != inputTensorInfo0.GetNumElements())
180 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000181 TF_LITE_MAYBE_KERNEL_LOG(
182 tfLiteContext,
183 "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
184 "operator #%d node #%d: ",
185 operatorCode, nodeIndex);
David Monahane03d9c22020-11-20 09:58:54 +0000186 return kTfLiteError;
David Monahan1670b0c2020-11-18 14:40:27 +0000187 }
188
189 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100190 armnn::BackendId setBackend;
David Monahan1670b0c2020-11-18 14:40:27 +0000191 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
192 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000193 FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
David Monahan1670b0c2020-11-18 14:40:27 +0000194 tfLiteContext,
195 IsReshapeSupported,
196 delegateData.m_Backends,
197 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100198 setBackend,
David Monahan1670b0c2020-11-18 14:40:27 +0000199 inputTensorInfo0,
200 outInfo,
201 reshapeDesc);
202 };
203
204 if (!delegateData.m_Network)
205 {
206 validateFunc(outputTensorInfo, isSupported);
207 return isSupported ? kTfLiteOk : kTfLiteError;
208 }
209
Mike Kelly07169c82023-08-02 13:23:09 +0100210 auto layerName = GetLayerName(armnn::LayerType::Reshape, nodeIndex);
211 armnn::IConnectableLayer* layer = delegateData.m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +0100212 layer->SetBackendId(setBackend);
David Monahan1670b0c2020-11-18 14:40:27 +0000213 ARMNN_ASSERT(layer != nullptr);
214
215 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
216 outputSlot.SetTensorInfo(outputTensorInfo);
217
Ryan OShea4c231de2023-01-17 15:19:20 +0000218 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100219 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +0000220 {
221 return kTfLiteError;
222 }
223
David Monahan1670b0c2020-11-18 14:40:27 +0000224 // Connect
225 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100226}
227
228TfLiteStatus VisitSqueezeOperator(DelegateData& delegateData,
229 TfLiteContext* tfLiteContext,
230 TfLiteNode* tfLiteNode,
231 int nodeIndex,
232 int32_t operatorCode)
233{
Matthew Sloyan3504e422023-05-03 13:53:02 +0100234 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
235 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +0000236
Matthew Sloyan3504e422023-05-03 13:53:02 +0100237 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
238 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
239 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
240 {
241 return kTfLiteError;
242 }
243
244 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
245 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
246 {
247 return kTfLiteError;
248 }
249
250 auto* options = reinterpret_cast<TfLiteSqueezeParams*>(tfLiteNode->builtin_data);
251
252 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
253
254 std::vector<uint32_t> squeezeDim;
255 // A single negative dim index is interpreted as a negative index in python
256 // Meaning the index will be the shape size plus the negative index value
257 if (options->num_squeeze_dims == 1 && options->squeeze_dims[0] < 0)
258 {
259 int32_t dim = static_cast<int32_t>(inputTensorInfo.GetShape().GetNumDimensions()) + options->squeeze_dims[0];
260 squeezeDim.push_back(static_cast<uint32_t>(dim));
261 }
262 else
263 {
264 for (int32_t i = 0; i < options->num_squeeze_dims; ++i)
265 {
266 squeezeDim.push_back(static_cast<uint32_t>(options->squeeze_dims[i]));
267 }
268 }
269
270 armnn::TensorInfo outputTensorInfo = OutputShapeOfSqueeze(squeezeDim, inputTensorInfo);
271
272 armnn::ReshapeDescriptor reshapeDesc;
273 reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
274
275 bool isSupported = false;
276 armnn::BackendId setBackend;
277 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
278 {
279 FORWARD_LAYER_SUPPORT_FUNC("SQUEEZE",
280 tfLiteContext,
281 IsReshapeSupported,
282 delegateData.m_Backends,
283 isSupported,
284 setBackend,
285 inputTensorInfo,
286 outInfo,
287 reshapeDesc);
288 };
289
290 if (!delegateData.m_Network)
291 {
292 validateFunc(outputTensorInfo, isSupported);
293 return isSupported ? kTfLiteOk : kTfLiteError;
294 }
295
Mike Kelly07169c82023-08-02 13:23:09 +0100296 auto layerName = GetLayerName(armnn::LayerType::Reshape, nodeIndex, "Squeeze");
297 armnn::IConnectableLayer* layer = delegateData.m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
Matthew Sloyan3504e422023-05-03 13:53:02 +0100298 layer->SetBackendId(setBackend);
299 ARMNN_ASSERT(layer != nullptr);
300
301 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
302 outputSlot.SetTensorInfo(outputTensorInfo);
303
304 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100305 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Matthew Sloyan3504e422023-05-03 13:53:02 +0100306 {
307 return kTfLiteError;
308 }
309
310 // Connect
311 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100312}
313
314TfLiteStatus VisitExpandDimsOperator(DelegateData& delegateData,
315 TfLiteContext* tfLiteContext,
316 TfLiteNode* tfLiteNode,
317 int nodeIndex,
318 int32_t operatorCode)
319{
Matthew Sloyan3504e422023-05-03 13:53:02 +0100320 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
321 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +0000322
Matthew Sloyan3504e422023-05-03 13:53:02 +0100323 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
324 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
325 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
326 {
327 return kTfLiteError;
328 }
329
330 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
331 if (!IsValid(tfLiteContext, tfLiteAxisTensor, operatorCode, nodeIndex))
332 {
333 return kTfLiteError;
334 }
335
336 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
337 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
338 {
339 return kTfLiteError;
340 }
341
342 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
343 armnn::TensorInfo outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
344
345 auto* axisTensorData = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
346 int32_t axis = axisTensorData[0];
347
348 int32_t inputDimSize = static_cast<int32_t>(inputTensorInfo.GetShape().GetNumDimensions());
349 if (axis > inputDimSize || axis < 0 - (inputDimSize + 1))
350 {
351 TF_LITE_MAYBE_KERNEL_LOG(
352 tfLiteContext,
353 "TfLiteArmnnOpaqueDelegate: Axis must be in range "
354 "[0 - (inputDimSize + 1), inputDimSize] inclusive.");
355 return kTfLiteError;
356 }
357
358 if(axis < 0)
359 {
360 axis = inputDimSize + axis + 1;
361 }
362
363 std::vector<unsigned int> shape(static_cast<unsigned int>(inputDimSize) + 1);
364 unsigned int inputShapeIndex = 0;
365 for (unsigned int i = 0; i < static_cast<unsigned int>(inputDimSize + 1); ++i)
366 {
367 if (i == static_cast<unsigned int>(axis))
368 {
369 shape[i] = 1;
370 }
371 else
372 {
373 shape[i] = inputTensorInfo.GetShape()[inputShapeIndex];
374 ++inputShapeIndex;
375 }
376 }
377
378 armnn::ReshapeDescriptor reshapeDesc;
379 reshapeDesc.m_TargetShape = armnn::TensorShape(static_cast<unsigned int>(inputDimSize + 1), shape.data());
380
381 bool isSupported = false;
382 armnn::BackendId setBackend;
383 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
384 {
385 FORWARD_LAYER_SUPPORT_FUNC("EXPAND_DIMS",
386 tfLiteContext,
387 IsReshapeSupported,
388 delegateData.m_Backends,
389 isSupported,
390 setBackend,
391 inputTensorInfo,
392 outInfo,
393 reshapeDesc);
394 };
395
396 if (!delegateData.m_Network)
397 {
398 validateFunc(outputTensorInfo, isSupported);
399 return isSupported ? kTfLiteOk : kTfLiteError;
400 }
401
Mike Kelly07169c82023-08-02 13:23:09 +0100402 auto layerName = GetLayerName(armnn::LayerType::Reshape, nodeIndex, "ExpandDims");
403 armnn::IConnectableLayer* layer = delegateData.m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
Matthew Sloyan3504e422023-05-03 13:53:02 +0100404 layer->SetBackendId(setBackend);
405 ARMNN_ASSERT(layer != nullptr);
406
407 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
408 outputTensorInfo.SetShape(reshapeDesc.m_TargetShape);
409 outputSlot.SetTensorInfo(outputTensorInfo);
410
411 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100412 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Matthew Sloyan3504e422023-05-03 13:53:02 +0100413 {
414 return kTfLiteError;
415 }
416
417 // Connect
418 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100419}
420
421} // namespace armnnDelegate