blob: 07491cee0d855ba4fde5cdbde8e19d2480b1683b [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <tensorflow/lite/builtin_ops.h>
9#include <tensorflow/lite/c/builtin_op_data.h>
10#include <tensorflow/lite/c/common.h>
11#include <tensorflow/lite/minimal_logging.h>
12
13namespace armnnDelegate
14{
15
16TfLiteStatus VisitBatchToSpaceNdOperator(DelegateData& delegateData,
17 TfLiteContext* tfLiteContext,
18 TfLiteNode* tfLiteNode,
19 int nodeIndex,
20 int32_t operatorCode)
21{
Matthew Sloyana35b40b2021-02-05 17:22:28 +000022 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
23 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
24
25 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
26 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
27 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
28 {
29 return kTfLiteError;
30 }
31
32 const TfLiteTensor& tfLiteBlockShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
33 if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
34 {
35 return kTfLiteError;
36 }
37
38 const TfLiteTensor& tfLiteCropsTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
39 if (!IsValid(tfLiteContext, tfLiteCropsTensor, operatorCode, nodeIndex))
40 {
41 return kTfLiteError;
42 }
43
44 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
45 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
46 {
47 return kTfLiteError;
48 }
49
50 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
51 const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBlockShapeTensor);
52 const armnn::TensorInfo& cropsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteCropsTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010053 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Matthew Sloyana35b40b2021-02-05 17:22:28 +000054
55 std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
56 ::memcpy(blockShape.data(), tfLiteBlockShapeTensor.data.data, blockShapeTensorInfo.GetNumBytes());
57
58 std::vector<unsigned int> cropsVector(cropsTensorInfo.GetNumElements());
59 std::memcpy(cropsVector.data(), tfLiteCropsTensor.data.data, cropsTensorInfo.GetNumBytes());
60
61 size_t step = 2;
62 std::vector<std::pair<unsigned int, unsigned int>> crops;
63 for (unsigned int i = 0; i < cropsTensorInfo.GetNumElements() / step; ++i)
64 {
65 crops.emplace_back(cropsVector[i * step], cropsVector[i * step + 1]);
66 }
67
68 armnn::BatchToSpaceNdDescriptor descriptor;
69 descriptor.m_BlockShape = blockShape;
70 descriptor.m_Crops = crops;
71 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
72
73 // Check if supported
74 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010075 armnn::BackendId setBackend;
Matthew Sloyana35b40b2021-02-05 17:22:28 +000076 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
77 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000078 FORWARD_LAYER_SUPPORT_FUNC("BATCH_TO_SPACE_ND",
Matthew Sloyana35b40b2021-02-05 17:22:28 +000079 tfLiteContext,
80 IsBatchToSpaceNdSupported,
81 delegateData.m_Backends,
82 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010083 setBackend,
Matthew Sloyana35b40b2021-02-05 17:22:28 +000084 inputTensorInfo,
85 outputTensorInfo,
86 descriptor);
87 };
88
89 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
90 // support for the operator
91 // If supported, VisitBatchToSpaceNdOperator will be called again to add the layer to the network as seen below
92 if (!delegateData.m_Network)
93 {
94 validateFunc(outputTensorInfo, isSupported);
95 return isSupported ? kTfLiteOk : kTfLiteError;
96 }
97
98 // Add a BatchToSpace layer
Mike Kelly07169c82023-08-02 13:23:09 +010099 auto layerName = GetLayerName(armnn::LayerType::BatchToSpaceNd, nodeIndex);
100 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchToSpaceNdLayer(descriptor, layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +0100101 layer->SetBackendId(setBackend);
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000102 ARMNN_ASSERT(layer != nullptr);
103
104 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
105 outputSlot.SetTensorInfo(outputTensorInfo);
106
Ryan OShea4c231de2023-01-17 15:19:20 +0000107 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100108 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +0000109 {
110 return kTfLiteError;
111 }
112
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000113 // Connect
114 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100115}
116
117TfLiteStatus VisitSpaceToBatchNdOperator(DelegateData& delegateData,
118 TfLiteContext* tfLiteContext,
119 TfLiteNode* tfLiteNode,
120 int nodeIndex,
121 int32_t operatorCode)
122{
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000123 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
124 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
125
126 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
127 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
128 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
129 {
130 return kTfLiteError;
131 }
132
133 const TfLiteTensor& tfLiteBlockShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
134 if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
135 {
136 return kTfLiteError;
137 }
138
139 const TfLiteTensor& tfLitePadListTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
140 if (!IsValid(tfLiteContext, tfLitePadListTensor, operatorCode, nodeIndex))
141 {
142 return kTfLiteError;
143 }
144
145 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
146 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
147 {
148 return kTfLiteError;
149 }
150
151 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
152 const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBlockShapeTensor);
153 const armnn::TensorInfo& padListTensorInfo = GetTensorInfoForTfLiteTensor(tfLitePadListTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +0100154 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000155
156 std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
157 std::memcpy(blockShape.data(), tfLiteBlockShapeTensor.data.data, blockShapeTensorInfo.GetNumBytes());
158
159 std::vector<unsigned int> padListVector(padListTensorInfo.GetNumElements());
160 std::memcpy(padListVector.data(), tfLitePadListTensor.data.data, padListTensorInfo.GetNumBytes());
161
162 size_t step = 2;
163 std::vector<std::pair<unsigned int, unsigned int>> padList;
164 for (unsigned int i = 0; i < padListTensorInfo.GetNumElements() / step; ++i)
165 {
166 padList.emplace_back(padListVector[i * step], padListVector[i * step + 1]);
167 }
168
169 armnn::SpaceToBatchNdDescriptor descriptor;
170 descriptor.m_BlockShape = blockShape;
171 descriptor.m_PadList = padList;
172 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
173
174 // Check if supported
175 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100176 armnn::BackendId setBackend;
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000177 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
178 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000179 FORWARD_LAYER_SUPPORT_FUNC("SPACE_TO_BATCH_ND",
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000180 tfLiteContext,
181 IsSpaceToBatchNdSupported,
182 delegateData.m_Backends,
183 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100184 setBackend,
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000185 inputTensorInfo,
186 outputTensorInfo,
187 descriptor);
188 };
189
190 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
191 // support for the operator
192 // If supported, VisitSpaceToBatchNdOperator will be called again to add the layer to the network as seen below
193 if (!delegateData.m_Network)
194 {
195 validateFunc(outputTensorInfo, isSupported);
196 return isSupported ? kTfLiteOk : kTfLiteError;
197 }
198
199 // Add a SpaceToBatch layer
200 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSpaceToBatchNdLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100201 layer->SetBackendId(setBackend);
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000202 ARMNN_ASSERT(layer != nullptr);
203
204 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
205 outputSlot.SetTensorInfo(outputTensorInfo);
206
Ryan OShea4c231de2023-01-17 15:19:20 +0000207 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100208 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +0000209 {
210 return kTfLiteError;
211 }
212
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000213 // Connect
214 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100215}
216
217} // namespace armnnDelegate