blob: 00e270558e8576ce698b06db0595175ca3bf5a2f [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
Kevin May81b66f32023-04-26 14:55:36 +01005
6#include <OpaqueDelegateUtils.hpp>
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
11#include <tensorflow/lite/minimal_logging.h>
12
13namespace armnnOpaqueDelegate
14{
15
16TfLiteStatus VisitBatchToSpaceNdOperator(DelegateData& delegateData,
17 TfLiteOpaqueContext* tfLiteContext,
18 TfLiteOpaqueNode* tfLiteNode,
19 int nodeIndex,
20 int32_t operatorCode)
21{
22 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
23 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
24
25 int numInputs = 3;
26 const int* inputTensors;
27 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
28 {
29 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
30 tfLiteContext,
31 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
32 nodeIndex);
33 return kTfLiteError;
34 }
35
36 int numOutputs = 0;
37 const int* outputTensors;
38 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
39 {
40 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
41 tfLiteContext,
42 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
43 nodeIndex);
44 return kTfLiteError;
45 }
46
47 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
48 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
49 {
50 return kTfLiteError;
51 }
52 const TfLiteOpaqueTensor* tfLiteBlockShapeTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
53 inputTensors[1]);
54 if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
55 {
56 return kTfLiteError;
57 }
58 const TfLiteOpaqueTensor* tfLiteCropsTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[2]);
59 if (!IsValid(tfLiteContext, tfLiteCropsTensor, operatorCode, nodeIndex))
60 {
61 return kTfLiteError;
62 }
63 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
64 outputTensors[0]);
65 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
66 {
67 return kTfLiteError;
68 }
69
70 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
71 const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteBlockShapeTensor);
72 const armnn::TensorInfo& cropsTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteCropsTensor);
73 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
74
75
76 // Copy memory into block and crops
77 std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
78 ::memcpy(blockShape.data(), TfLiteOpaqueTensorData(tfLiteBlockShapeTensor), blockShapeTensorInfo.GetNumBytes());
79
80 std::vector<unsigned int> cropsVector(cropsTensorInfo.GetNumElements());
81 std::memcpy(cropsVector.data(), TfLiteOpaqueTensorData(tfLiteCropsTensor), cropsTensorInfo.GetNumBytes());
82
83 size_t step = 2;
84 std::vector<std::pair<unsigned int, unsigned int>> crops;
85 for (unsigned int i = 0; i < cropsTensorInfo.GetNumElements() / step; ++i)
86 {
87 crops.emplace_back(cropsVector[i * step], cropsVector[i * step + 1]);
88 }
89
90 // Make a descriptor
91 armnn::BatchToSpaceNdDescriptor descriptor;
92 descriptor.m_BlockShape = blockShape;
93 descriptor.m_Crops = crops;
94 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
95
96 // Check if supported
97 bool isSupported = false;
98 armnn::BackendId setBackend;
99 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
100 {
101 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("BATCH_TO_SPACE_ND",
102 tfLiteContext,
103 IsBatchToSpaceNdSupported,
104 delegateData.m_Backends,
105 isSupported,
106 setBackend,
107 inputTensorInfo,
108 outputTensorInfo,
109 descriptor);
110 };
111
112 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
113 // support for the operator
114 // If supported, VisitBatchToSpaceNdOperator will be called again to add the layer to the network as seen below
115 if (!delegateData.m_Network)
116 {
117 validateFunc(outputTensorInfo, isSupported);
118 return isSupported ? kTfLiteOk : kTfLiteError;
119 }
120
121 // Add a BatchToSpace layer
Mike Kellya2806502023-08-03 10:42:11 +0100122 auto layerName = GetName(armnn::LayerType::BatchToSpaceNd, nodeIndex);
123 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchToSpaceNdLayer(descriptor, layerName.c_str());
Kevin May81b66f32023-04-26 14:55:36 +0100124 layer->SetBackendId(setBackend);
125 ARMNN_ASSERT(layer != nullptr);
126
127 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
128 outputSlot.SetTensorInfo(outputTensorInfo);
129
130 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100131 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Kevin May81b66f32023-04-26 14:55:36 +0100132 {
133 return kTfLiteError;
134 }
135
136 // Connect
137 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
138}
139
140TfLiteStatus VisitSpaceToBatchNdOperator(DelegateData& delegateData,
141 TfLiteOpaqueContext* tfLiteContext,
142 TfLiteOpaqueNode* tfLiteNode,
143 int nodeIndex,
144 int32_t operatorCode)
145{
146 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
147 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
148
149 int numInputs = 3;
150 const int* inputTensors;
151 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
152 {
153 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
154 tfLiteContext,
155 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
156 nodeIndex);
157 return kTfLiteError;
158 }
159
160 int numOutputs = 0;
161 const int* outputTensors;
162 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
163 {
164 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
165 tfLiteContext,
166 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
167 nodeIndex);
168 return kTfLiteError;
169 }
170
171 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,inputTensors[0]);
172 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
173 {
174 return kTfLiteError;
175 }
176 const TfLiteOpaqueTensor* tfLiteBlockShapeTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
177 inputTensors[1]);
178 if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
179 {
180 return kTfLiteError;
181 }
182 const TfLiteOpaqueTensor* tfLitePadListTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
183 inputTensors[2]);
184 if (!IsValid(tfLiteContext, tfLitePadListTensor, operatorCode, nodeIndex))
185 {
186 return kTfLiteError;
187 }
188 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
189 outputTensors[0]);
190 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
191 {
192 return kTfLiteError;
193 }
194
195 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
196 const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteBlockShapeTensor);
197 const armnn::TensorInfo& padListTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLitePadListTensor);
198 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
199
200 std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
201 std::memcpy(blockShape.data(),
202 TfLiteOpaqueTensorData(tfLiteBlockShapeTensor),
203 blockShapeTensorInfo.GetNumBytes());
204
205 std::vector<unsigned int> padListVector(padListTensorInfo.GetNumElements());
206 std::memcpy(padListVector.data(),
207 TfLiteOpaqueTensorData(tfLitePadListTensor),
208 padListTensorInfo.GetNumBytes());
209
210 size_t step = 2;
211 std::vector<std::pair<unsigned int, unsigned int>> padList;
212 for (unsigned int i = 0; i < padListTensorInfo.GetNumElements() / step; ++i)
213 {
214 padList.emplace_back(padListVector[i * step], padListVector[i * step + 1]);
215 }
216
217 armnn::SpaceToBatchNdDescriptor descriptor;
218 descriptor.m_BlockShape = blockShape;
219 descriptor.m_PadList = padList;
220 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
221
222 // Check if supported
223 bool isSupported = false;
224 armnn::BackendId setBackend;
225 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
226 {
227 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SPACE_TO_BATCH_ND",
228 tfLiteContext,
229 IsSpaceToBatchNdSupported,
230 delegateData.m_Backends,
231 isSupported,
232 setBackend,
233 inputTensorInfo,
234 outputTensorInfo,
235 descriptor);
236 };
237
238 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
239 // support for the operator
240 // If supported, VisitSpaceToBatchNdOperator will be called again to add the layer to the network as seen below
241 if (!delegateData.m_Network)
242 {
243 validateFunc(outputTensorInfo, isSupported);
244 return isSupported ? kTfLiteOk : kTfLiteError;
245 }
246
247 // Add a SpaceToBatch layer
Mike Kellya2806502023-08-03 10:42:11 +0100248 auto layerName = GetName(armnn::LayerType::SpaceToBatchNd, nodeIndex);
249 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSpaceToBatchNdLayer(descriptor, layerName.c_str());
Kevin May81b66f32023-04-26 14:55:36 +0100250 layer->SetBackendId(setBackend);
251 ARMNN_ASSERT(layer != nullptr);
252
253 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
254 outputSlot.SetTensorInfo(outputTensorInfo);
255
256 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100257 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Kevin May81b66f32023-04-26 14:55:36 +0100258 {
259 return kTfLiteError;
260 }
261
262 // Connect
263 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
264}
265
266} // namespace