blob: ce25f7f18bc0b9c159cfc5e86c5d5a08d105bd4f [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Finn Williams6f9f9902020-11-13 13:23:15 +00008#include <armnn/utility/IgnoreUnused.hpp>
9
David Monahan1670b0c2020-11-18 14:40:27 +000010#include "DelegateUtils.hpp"
11
Sadik Armagan62483be2020-10-23 17:14:43 +010012#include <tensorflow/lite/builtin_ops.h>
13#include <tensorflow/lite/c/builtin_op_data.h>
14#include <tensorflow/lite/c/common.h>
15#include <tensorflow/lite/minimal_logging.h>
David Monahan1670b0c2020-11-18 14:40:27 +000016#include <numeric>
Sadik Armagan62483be2020-10-23 17:14:43 +010017
18namespace armnnDelegate
19{
20
Sadik Armagan937565b2021-04-21 14:03:28 +010021TfLiteStatus VisitCastOperator(DelegateData& delegateData,
22 TfLiteContext* tfLiteContext,
23 TfLiteNode* tfLiteNode,
24 int nodeIndex,
25 int32_t operatorCode)
26{
27 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
29
30 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
31 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
32 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
33 {
34 return kTfLiteError;
35 }
36
37 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
38 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
39 {
40 return kTfLiteError;
41 }
42
43 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
44 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
45
46 bool isSupported = false;
47 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
48 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000049 FORWARD_LAYER_SUPPORT_FUNC("CAST",
Sadik Armagan937565b2021-04-21 14:03:28 +010050 tfLiteContext,
51 IsCastSupported,
52 delegateData.m_Backends,
53 isSupported,
54 inputTensorInfo,
55 outInfo);
56 };
57
58 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
59 // support for the operator
60 // If supported, VisitCastOperator will be called again to add the layer to the network as seen further below
61 if (!delegateData.m_Network)
62 {
63 validateFunc(outputTensorInfo, isSupported);
64 return isSupported ? kTfLiteOk : kTfLiteError;
65 }
66
67 // Add a Cast layer
68 armnn::IConnectableLayer* layer = delegateData.m_Network->AddCastLayer();
69 ARMNN_ASSERT(layer != nullptr);
70
71 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
72 outputSlot.SetTensorInfo(outputTensorInfo);
73
74 // Connect
75 return Connect(layer, tfLiteNode, delegateData);
76}
77
78
David Monahan1670b0c2020-11-18 14:40:27 +000079TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo,
Matthew Sloyanf00f6c22020-12-07 13:33:24 +000080 const std::vector<int32_t>& targetShape,
81 armnn::ReshapeDescriptor& reshapeDesc)
David Monahan1670b0c2020-11-18 14:40:27 +000082{
83 std::vector<unsigned int> outputDims(targetShape.begin(), targetShape.end());
84 const auto stretchDim = std::find(targetShape.begin(), targetShape.end(), -1);
85
86 if (stretchDim != targetShape.end())
87 {
88 if (std::find(std::next(stretchDim), targetShape.end(), -1) != targetShape.end())
89 {
90 // Return kTfLiteError and log the error after returning
91 return kTfLiteError;
92 }
93
94 auto targetNumElements =
95 armnn::numeric_cast<unsigned int>(
96 std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies<int32_t>()));
97
98 auto stretchIndex = static_cast<size_t>(std::distance(targetShape.begin(), stretchDim));
99 outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
100 }
101
102 armnn::TensorShape outputShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()),
103 outputDims.data());
104 reshapeDesc.m_TargetShape = outputShape;
105 return kTfLiteOk;
106}
107
Sadik Armagan62483be2020-10-23 17:14:43 +0100108TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
109 TfLiteContext* tfLiteContext,
110 TfLiteNode* tfLiteNode,
111 int nodeIndex,
112 int32_t operatorCode)
113{
David Monahan1670b0c2020-11-18 14:40:27 +0000114 auto numInputs = tfLiteNode->inputs->size;
Finn Williams6f9f9902020-11-13 13:23:15 +0000115
David Monahan1670b0c2020-11-18 14:40:27 +0000116 if (numInputs == 2)
117 {
118 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
119 }
120 else
121 {
122 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
123 }
124 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
125
126 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
127 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000128 if (!IsValid(tfLiteContext, tfLiteInputTensor0, operatorCode, nodeIndex))
David Monahan1670b0c2020-11-18 14:40:27 +0000129 {
David Monahan1670b0c2020-11-18 14:40:27 +0000130 return kTfLiteError;
131 }
132
133 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000134 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
David Monahan1670b0c2020-11-18 14:40:27 +0000135 {
David Monahan1670b0c2020-11-18 14:40:27 +0000136 return kTfLiteError;
137 }
138
139 const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
140 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
141
142 armnn::ReshapeDescriptor reshapeDesc;
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000143 std::vector<int32_t> targetShape;
Finn Williamsf806c4d2021-02-22 15:13:12 +0000144
145 TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);
David Monahan1670b0c2020-11-18 14:40:27 +0000146
147 // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both.
Finn Williamsf806c4d2021-02-22 15:13:12 +0000148 // Options might be set without valid data. we need to check the dimensions are in a valid range.
149 if (reshapeOptions && reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
150 {
151 for (int i=0; i < reshapeOptions->num_dimensions; ++i)
152 {
153 targetShape.push_back(reshapeOptions->shape[i]);
154 }
155 }
156 else if (numInputs == 2)
David Monahan1670b0c2020-11-18 14:40:27 +0000157 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000158 // Get shape from the second input tensor
159 const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000160 if (!IsValid(tfLiteContext, tfLiteShapeInputTensor, operatorCode, nodeIndex))
David Monahane03d9c22020-11-20 09:58:54 +0000161 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000162 return kTfLiteError;
163 }
164
165 if (tfLiteShapeInputTensor.dims->size != 1)
166 {
167 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
168 "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000169 "operator #%d node #%d: Falling back to TfLiteOptions.",
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000170 operatorCode, nodeIndex);
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000171 }
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000172 else
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000173 {
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000174 // Get the shape data out of the input tensor
175 auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
176 auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
177 for (auto i=0; i < shapeTensorNumValues; ++i)
178 {
179 targetShape.push_back(*(shapeTensorDataPtr+i));
180 }
David Monahane03d9c22020-11-20 09:58:54 +0000181 }
182 }
Finn Williamsf806c4d2021-02-22 15:13:12 +0000183 else
David Monahane03d9c22020-11-20 09:58:54 +0000184 {
Finn Williamsf806c4d2021-02-22 15:13:12 +0000185 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
186 "Target shape not defined in reshape parameters or input tensor. "
187 "At least one method required in operator #%d node #%d: ",
188 operatorCode, nodeIndex);
189 return kTfLiteError;
David Monahan1670b0c2020-11-18 14:40:27 +0000190 }
David Monahane03d9c22020-11-20 09:58:54 +0000191
192 // Use the data to create the required tensor shape.
193 if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
David Monahan1670b0c2020-11-18 14:40:27 +0000194 {
195 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
David Monahane03d9c22020-11-20 09:58:54 +0000196 "TfLiteArmnnDelegate: At most one component of shape can be -1 in: "
197 "operator #%d node #%d: ",
David Monahan1670b0c2020-11-18 14:40:27 +0000198 operatorCode, nodeIndex);
David Monahane03d9c22020-11-20 09:58:54 +0000199 return kTfLiteError;
200 }
201
202 if (reshapeDesc.m_TargetShape.GetNumElements() != inputTensorInfo0.GetNumElements())
203 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000204 TF_LITE_MAYBE_KERNEL_LOG(
205 tfLiteContext,
206 "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
207 "operator #%d node #%d: ",
208 operatorCode, nodeIndex);
David Monahane03d9c22020-11-20 09:58:54 +0000209 return kTfLiteError;
David Monahan1670b0c2020-11-18 14:40:27 +0000210 }
211
212 bool isSupported = false;
213 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
214 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000215 FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
David Monahan1670b0c2020-11-18 14:40:27 +0000216 tfLiteContext,
217 IsReshapeSupported,
218 delegateData.m_Backends,
219 isSupported,
220 inputTensorInfo0,
221 outInfo,
222 reshapeDesc);
223 };
224
225 if (!delegateData.m_Network)
226 {
227 validateFunc(outputTensorInfo, isSupported);
228 return isSupported ? kTfLiteOk : kTfLiteError;
229 }
230
231 armnn::IConnectableLayer* layer = delegateData.m_Network->AddReshapeLayer(reshapeDesc);
232 ARMNN_ASSERT(layer != nullptr);
233
234 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
235 outputSlot.SetTensorInfo(outputTensorInfo);
236
237 // Connect
238 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100239}
240
241TfLiteStatus VisitSqueezeOperator(DelegateData& delegateData,
242 TfLiteContext* tfLiteContext,
243 TfLiteNode* tfLiteNode,
244 int nodeIndex,
245 int32_t operatorCode)
246{
Finn Williams6f9f9902020-11-13 13:23:15 +0000247 armnn::IgnoreUnused(delegateData,
248 tfLiteContext,
249 tfLiteNode,
250 nodeIndex,
251 operatorCode);
252
Sadik Armagan62483be2020-10-23 17:14:43 +0100253 return kTfLiteError;
254}
255
256TfLiteStatus VisitExpandDimsOperator(DelegateData& delegateData,
257 TfLiteContext* tfLiteContext,
258 TfLiteNode* tfLiteNode,
259 int nodeIndex,
260 int32_t operatorCode)
261{
Finn Williams6f9f9902020-11-13 13:23:15 +0000262 armnn::IgnoreUnused(delegateData,
263 tfLiteContext,
264 tfLiteNode,
265 nodeIndex,
266 operatorCode);
267
Sadik Armagan62483be2020-10-23 17:14:43 +0100268 return kTfLiteError;
269}
270
271} // namespace armnnDelegate