blob: 30c6dbfc15ef729ed936890ba7c359a14ad88967 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan62483be2020-10-23 17:14:43 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

#include <cstring>
#include <utility>
#include <vector>
12
13namespace armnnDelegate
14{
15
16TfLiteStatus VisitBatchToSpaceNdOperator(DelegateData& delegateData,
17 TfLiteContext* tfLiteContext,
18 TfLiteNode* tfLiteNode,
19 int nodeIndex,
20 int32_t operatorCode)
21{
Matthew Sloyana35b40b2021-02-05 17:22:28 +000022 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
23 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
24
25 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
26 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
27 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
28 {
29 return kTfLiteError;
30 }
31
32 const TfLiteTensor& tfLiteBlockShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
33 if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
34 {
35 return kTfLiteError;
36 }
37
38 const TfLiteTensor& tfLiteCropsTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
39 if (!IsValid(tfLiteContext, tfLiteCropsTensor, operatorCode, nodeIndex))
40 {
41 return kTfLiteError;
42 }
43
44 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
45 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
46 {
47 return kTfLiteError;
48 }
49
50 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
51 const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBlockShapeTensor);
52 const armnn::TensorInfo& cropsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteCropsTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010053 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Matthew Sloyana35b40b2021-02-05 17:22:28 +000054
55 std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
56 ::memcpy(blockShape.data(), tfLiteBlockShapeTensor.data.data, blockShapeTensorInfo.GetNumBytes());
57
58 std::vector<unsigned int> cropsVector(cropsTensorInfo.GetNumElements());
59 std::memcpy(cropsVector.data(), tfLiteCropsTensor.data.data, cropsTensorInfo.GetNumBytes());
60
61 size_t step = 2;
62 std::vector<std::pair<unsigned int, unsigned int>> crops;
63 for (unsigned int i = 0; i < cropsTensorInfo.GetNumElements() / step; ++i)
64 {
65 crops.emplace_back(cropsVector[i * step], cropsVector[i * step + 1]);
66 }
67
68 armnn::BatchToSpaceNdDescriptor descriptor;
69 descriptor.m_BlockShape = blockShape;
70 descriptor.m_Crops = crops;
71 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
72
73 // Check if supported
74 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +010075 armnn::BackendId setBackend;
Matthew Sloyana35b40b2021-02-05 17:22:28 +000076 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
77 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +000078 FORWARD_LAYER_SUPPORT_FUNC("BATCH_TO_SPACE_ND",
Matthew Sloyana35b40b2021-02-05 17:22:28 +000079 tfLiteContext,
80 IsBatchToSpaceNdSupported,
81 delegateData.m_Backends,
82 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010083 setBackend,
Matthew Sloyana35b40b2021-02-05 17:22:28 +000084 inputTensorInfo,
85 outputTensorInfo,
86 descriptor);
87 };
88
89 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
90 // support for the operator
91 // If supported, VisitBatchToSpaceNdOperator will be called again to add the layer to the network as seen below
92 if (!delegateData.m_Network)
93 {
94 validateFunc(outputTensorInfo, isSupported);
95 return isSupported ? kTfLiteOk : kTfLiteError;
96 }
97
98 // Add a BatchToSpace layer
99 armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchToSpaceNdLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100100 layer->SetBackendId(setBackend);
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000101 ARMNN_ASSERT(layer != nullptr);
102
103 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
104 outputSlot.SetTensorInfo(outputTensorInfo);
105
Ryan OShea4c231de2023-01-17 15:19:20 +0000106 // try to connect the Constant Inputs if there are any
107 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
108 {
109 return kTfLiteError;
110 }
111
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000112 // Connect
113 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100114}
115
116TfLiteStatus VisitSpaceToBatchNdOperator(DelegateData& delegateData,
117 TfLiteContext* tfLiteContext,
118 TfLiteNode* tfLiteNode,
119 int nodeIndex,
120 int32_t operatorCode)
121{
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000122 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
123 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
124
125 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
126 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
127 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
128 {
129 return kTfLiteError;
130 }
131
132 const TfLiteTensor& tfLiteBlockShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
133 if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
134 {
135 return kTfLiteError;
136 }
137
138 const TfLiteTensor& tfLitePadListTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
139 if (!IsValid(tfLiteContext, tfLitePadListTensor, operatorCode, nodeIndex))
140 {
141 return kTfLiteError;
142 }
143
144 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
145 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
146 {
147 return kTfLiteError;
148 }
149
150 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
151 const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBlockShapeTensor);
152 const armnn::TensorInfo& padListTensorInfo = GetTensorInfoForTfLiteTensor(tfLitePadListTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +0100153 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000154
155 std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
156 std::memcpy(blockShape.data(), tfLiteBlockShapeTensor.data.data, blockShapeTensorInfo.GetNumBytes());
157
158 std::vector<unsigned int> padListVector(padListTensorInfo.GetNumElements());
159 std::memcpy(padListVector.data(), tfLitePadListTensor.data.data, padListTensorInfo.GetNumBytes());
160
161 size_t step = 2;
162 std::vector<std::pair<unsigned int, unsigned int>> padList;
163 for (unsigned int i = 0; i < padListTensorInfo.GetNumElements() / step; ++i)
164 {
165 padList.emplace_back(padListVector[i * step], padListVector[i * step + 1]);
166 }
167
168 armnn::SpaceToBatchNdDescriptor descriptor;
169 descriptor.m_BlockShape = blockShape;
170 descriptor.m_PadList = padList;
171 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
172
173 // Check if supported
174 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100175 armnn::BackendId setBackend;
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000176 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
177 {
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000178 FORWARD_LAYER_SUPPORT_FUNC("SPACE_TO_BATCH_ND",
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000179 tfLiteContext,
180 IsSpaceToBatchNdSupported,
181 delegateData.m_Backends,
182 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100183 setBackend,
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000184 inputTensorInfo,
185 outputTensorInfo,
186 descriptor);
187 };
188
189 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
190 // support for the operator
191 // If supported, VisitSpaceToBatchNdOperator will be called again to add the layer to the network as seen below
192 if (!delegateData.m_Network)
193 {
194 validateFunc(outputTensorInfo, isSupported);
195 return isSupported ? kTfLiteOk : kTfLiteError;
196 }
197
198 // Add a SpaceToBatch layer
199 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSpaceToBatchNdLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100200 layer->SetBackendId(setBackend);
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000201 ARMNN_ASSERT(layer != nullptr);
202
203 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
204 outputSlot.SetTensorInfo(outputTensorInfo);
205
Ryan OShea4c231de2023-01-17 15:19:20 +0000206 // try to connect the Constant Inputs if there are any
207 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
208 {
209 return kTfLiteError;
210 }
211
Matthew Sloyana35b40b2021-02-05 17:22:28 +0000212 // Connect
213 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100214}
215
216} // namespace armnnDelegate