blob: 3df26cacc319643c1feafb4686f8c0a9b57c8cf3 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Finn Williams6f9f9902020-11-13 13:23:15 +00008#include <armnn/utility/IgnoreUnused.hpp>
9
David Monahan1670b0c2020-11-18 14:40:27 +000010#include "DelegateUtils.hpp"
11
Sadik Armagan62483be2020-10-23 17:14:43 +010012#include <tensorflow/lite/builtin_ops.h>
13#include <tensorflow/lite/c/builtin_op_data.h>
14#include <tensorflow/lite/c/common.h>
15#include <tensorflow/lite/minimal_logging.h>
David Monahan1670b0c2020-11-18 14:40:27 +000016#include <numeric>
Sadik Armagan62483be2020-10-23 17:14:43 +010017
18namespace armnnDelegate
19{
20
David Monahan1670b0c2020-11-18 14:40:27 +000021TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo,
Matthew Sloyanf00f6c22020-12-07 13:33:24 +000022 const std::vector<int32_t>& targetShape,
23 armnn::ReshapeDescriptor& reshapeDesc)
David Monahan1670b0c2020-11-18 14:40:27 +000024{
25 std::vector<unsigned int> outputDims(targetShape.begin(), targetShape.end());
26 const auto stretchDim = std::find(targetShape.begin(), targetShape.end(), -1);
27
28 if (stretchDim != targetShape.end())
29 {
30 if (std::find(std::next(stretchDim), targetShape.end(), -1) != targetShape.end())
31 {
32 // Return kTfLiteError and log the error after returning
33 return kTfLiteError;
34 }
35
36 auto targetNumElements =
37 armnn::numeric_cast<unsigned int>(
38 std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies<int32_t>()));
39
40 auto stretchIndex = static_cast<size_t>(std::distance(targetShape.begin(), stretchDim));
41 outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
42 }
43
44 armnn::TensorShape outputShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()),
45 outputDims.data());
46 reshapeDesc.m_TargetShape = outputShape;
47 return kTfLiteOk;
48}
49
Sadik Armagan62483be2020-10-23 17:14:43 +010050TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
51 TfLiteContext* tfLiteContext,
52 TfLiteNode* tfLiteNode,
53 int nodeIndex,
54 int32_t operatorCode)
55{
David Monahan1670b0c2020-11-18 14:40:27 +000056 auto numInputs = tfLiteNode->inputs->size;
Finn Williams6f9f9902020-11-13 13:23:15 +000057
David Monahan1670b0c2020-11-18 14:40:27 +000058 if (numInputs == 2)
59 {
60 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
61 }
62 else
63 {
64 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
65 }
66 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
67
68 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
69 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +000070 if (!IsValid(tfLiteContext, tfLiteInputTensor0, operatorCode, nodeIndex))
David Monahan1670b0c2020-11-18 14:40:27 +000071 {
David Monahan1670b0c2020-11-18 14:40:27 +000072 return kTfLiteError;
73 }
74
75 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +000076 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
David Monahan1670b0c2020-11-18 14:40:27 +000077 {
David Monahan1670b0c2020-11-18 14:40:27 +000078 return kTfLiteError;
79 }
80
81 const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
82 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
83
84 armnn::ReshapeDescriptor reshapeDesc;
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +000085 std::vector<int32_t> targetShape;
Finn Williamsf806c4d2021-02-22 15:13:12 +000086
87 TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);
David Monahan1670b0c2020-11-18 14:40:27 +000088
89 // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both.
Finn Williamsf806c4d2021-02-22 15:13:12 +000090 // Options might be set without valid data. we need to check the dimensions are in a valid range.
91 if (reshapeOptions && reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
92 {
93 for (int i=0; i < reshapeOptions->num_dimensions; ++i)
94 {
95 targetShape.push_back(reshapeOptions->shape[i]);
96 }
97 }
98 else if (numInputs == 2)
David Monahan1670b0c2020-11-18 14:40:27 +000099 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000100 // Get shape from the second input tensor
101 const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000102 if (!IsValid(tfLiteContext, tfLiteShapeInputTensor, operatorCode, nodeIndex))
David Monahane03d9c22020-11-20 09:58:54 +0000103 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000104 return kTfLiteError;
105 }
106
107 if (tfLiteShapeInputTensor.dims->size != 1)
108 {
109 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
110 "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000111 "operator #%d node #%d: Falling back to TfLiteOptions.",
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000112 operatorCode, nodeIndex);
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000113 }
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000114 else
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000115 {
Matthew Sloyanf00f6c22020-12-07 13:33:24 +0000116 // Get the shape data out of the input tensor
117 auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
118 auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
119 for (auto i=0; i < shapeTensorNumValues; ++i)
120 {
121 targetShape.push_back(*(shapeTensorDataPtr+i));
122 }
David Monahane03d9c22020-11-20 09:58:54 +0000123 }
124 }
Finn Williamsf806c4d2021-02-22 15:13:12 +0000125 else
David Monahane03d9c22020-11-20 09:58:54 +0000126 {
Finn Williamsf806c4d2021-02-22 15:13:12 +0000127 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
128 "Target shape not defined in reshape parameters or input tensor. "
129 "At least one method required in operator #%d node #%d: ",
130 operatorCode, nodeIndex);
131 return kTfLiteError;
David Monahan1670b0c2020-11-18 14:40:27 +0000132 }
David Monahane03d9c22020-11-20 09:58:54 +0000133
134 // Use the data to create the required tensor shape.
135 if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
David Monahan1670b0c2020-11-18 14:40:27 +0000136 {
137 TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
David Monahane03d9c22020-11-20 09:58:54 +0000138 "TfLiteArmnnDelegate: At most one component of shape can be -1 in: "
139 "operator #%d node #%d: ",
David Monahan1670b0c2020-11-18 14:40:27 +0000140 operatorCode, nodeIndex);
David Monahane03d9c22020-11-20 09:58:54 +0000141 return kTfLiteError;
142 }
143
144 if (reshapeDesc.m_TargetShape.GetNumElements() != inputTensorInfo0.GetNumElements())
145 {
Narumol Prangnawarat7f6c6672020-11-24 18:40:42 +0000146 TF_LITE_MAYBE_KERNEL_LOG(
147 tfLiteContext,
148 "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
149 "operator #%d node #%d: ",
150 operatorCode, nodeIndex);
David Monahane03d9c22020-11-20 09:58:54 +0000151 return kTfLiteError;
David Monahan1670b0c2020-11-18 14:40:27 +0000152 }
153
154 bool isSupported = false;
155 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
156 {
157 FORWARD_LAYER_SUPPORT_FUNC(__func__,
158 tfLiteContext,
159 IsReshapeSupported,
160 delegateData.m_Backends,
161 isSupported,
162 inputTensorInfo0,
163 outInfo,
164 reshapeDesc);
165 };
166
167 if (!delegateData.m_Network)
168 {
169 validateFunc(outputTensorInfo, isSupported);
170 return isSupported ? kTfLiteOk : kTfLiteError;
171 }
172
173 armnn::IConnectableLayer* layer = delegateData.m_Network->AddReshapeLayer(reshapeDesc);
174 ARMNN_ASSERT(layer != nullptr);
175
176 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
177 outputSlot.SetTensorInfo(outputTensorInfo);
178
179 // Connect
180 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100181}
182
183TfLiteStatus VisitSqueezeOperator(DelegateData& delegateData,
184 TfLiteContext* tfLiteContext,
185 TfLiteNode* tfLiteNode,
186 int nodeIndex,
187 int32_t operatorCode)
188{
Finn Williams6f9f9902020-11-13 13:23:15 +0000189 armnn::IgnoreUnused(delegateData,
190 tfLiteContext,
191 tfLiteNode,
192 nodeIndex,
193 operatorCode);
194
Sadik Armagan62483be2020-10-23 17:14:43 +0100195 return kTfLiteError;
196}
197
198TfLiteStatus VisitExpandDimsOperator(DelegateData& delegateData,
199 TfLiteContext* tfLiteContext,
200 TfLiteNode* tfLiteNode,
201 int nodeIndex,
202 int32_t operatorCode)
203{
Finn Williams6f9f9902020-11-13 13:23:15 +0000204 armnn::IgnoreUnused(delegateData,
205 tfLiteContext,
206 tfLiteNode,
207 nodeIndex,
208 operatorCode);
209
Sadik Armagan62483be2020-10-23 17:14:43 +0100210 return kTfLiteError;
211}
212
213} // namespace armnnDelegate