//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnDelegate
{

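// Validates a BATCH_TO_SPACE_ND node and, if delegateData.m_Network is non-null, adds the
// corresponding Arm NN BatchToSpaceNd layer to the network being built.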
TfLiteStatus VisitBatchToSpaceNdOperator(DelegateData& delegateData,
                                         TfLiteContext* tfLiteContext,
                                         TfLiteNode* tfLiteNode,
                                         int nodeIndex,
                                         int32_t operatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteBlockShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
    if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteCropsTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
    if (!IsValid(tfLiteContext, tfLiteCropsTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBlockShapeTensor);
    const armnn::TensorInfo& cropsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteCropsTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);

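    // Copy the block-shape and crops tensors into host-side vectors. The byte-wise copy assumes
    // the TfLite tensors hold 32-bit integers, the same width as unsigned int.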
    std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
    std::memcpy(blockShape.data(), tfLiteBlockShapeTensor.data.data, blockShapeTensorInfo.GetNumBytes());

    std::vector<unsigned int> cropsVector(cropsTensorInfo.GetNumElements());
    std::memcpy(cropsVector.data(), tfLiteCropsTensor.data.data, cropsTensorInfo.GetNumBytes());

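    // The crops tensor is laid out as [spatialDims, 2]: each consecutive pair of elements holds
    // the (begin, end) crop amounts for one spatial dimension.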
    size_t step = 2;
    std::vector<std::pair<unsigned int, unsigned int>> crops;
    for (unsigned int i = 0; i < cropsTensorInfo.GetNumElements() / step; ++i)
    {
        crops.emplace_back(cropsVector[i * step], cropsVector[i * step + 1]);
    }

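    // TfLite tensors are laid out in NHWC order; configure the descriptor to match.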
    armnn::BatchToSpaceNdDescriptor descriptor;
    descriptor.m_BlockShape = blockShape;
    descriptor.m_Crops = crops;
    descriptor.m_DataLayout = armnn::DataLayout::NHWC;

    // Check if supported
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("BATCH_TO_SPACE_ND",
                                   tfLiteContext,
                                   IsBatchToSpaceNdSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   inputTensorInfo,
                                   outputTensorInfo,
                                   descriptor);
    };

    // If m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to
    // clarify support for the operator. If supported, VisitBatchToSpaceNdOperator will be called
    // again to add the layer to the network, as seen below.
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    // Add a BatchToSpace layer
    armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchToSpaceNdLayer(descriptor);
    ARMNN_ASSERT(layer != nullptr);

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Connect
    return Connect(layer, tfLiteNode, delegateData);
}

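// Validates a SPACE_TO_BATCH_ND node and, if delegateData.m_Network is non-null, adds the
// corresponding Arm NN SpaceToBatchNd layer to the network being built.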
TfLiteStatus VisitSpaceToBatchNdOperator(DelegateData& delegateData,
                                         TfLiteContext* tfLiteContext,
                                         TfLiteNode* tfLiteNode,
                                         int nodeIndex,
                                         int32_t operatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteBlockShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
    if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLitePadListTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
    if (!IsValid(tfLiteContext, tfLitePadListTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBlockShapeTensor);
    const armnn::TensorInfo& padListTensorInfo = GetTensorInfoForTfLiteTensor(tfLitePadListTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);

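    // Copy the block-shape and pad-list tensors into host-side vectors. The byte-wise copy assumes
    // the TfLite tensors hold 32-bit integers, the same width as unsigned int.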
    std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
    std::memcpy(blockShape.data(), tfLiteBlockShapeTensor.data.data, blockShapeTensorInfo.GetNumBytes());

    std::vector<unsigned int> padListVector(padListTensorInfo.GetNumElements());
    std::memcpy(padListVector.data(), tfLitePadListTensor.data.data, padListTensorInfo.GetNumBytes());

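    // The pad-list tensor is laid out as [spatialDims, 2]: each consecutive pair of elements holds
    // the (padBefore, padAfter) amounts for one spatial dimension.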
    size_t step = 2;
    std::vector<std::pair<unsigned int, unsigned int>> padList;
    for (unsigned int i = 0; i < padListTensorInfo.GetNumElements() / step; ++i)
    {
        padList.emplace_back(padListVector[i * step], padListVector[i * step + 1]);
    }

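    // TfLite tensors are laid out in NHWC order; configure the descriptor to match.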
    armnn::SpaceToBatchNdDescriptor descriptor;
    descriptor.m_BlockShape = blockShape;
    descriptor.m_PadList = padList;
    descriptor.m_DataLayout = armnn::DataLayout::NHWC;

    // Check if supported
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("SPACE_TO_BATCH_ND",
                                   tfLiteContext,
                                   IsSpaceToBatchNdSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   inputTensorInfo,
                                   outputTensorInfo,
                                   descriptor);
    };

    // If m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to
    // clarify support for the operator. If supported, VisitSpaceToBatchNdOperator will be called
    // again to add the layer to the network, as seen below.
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    // Add a SpaceToBatch layer
    armnn::IConnectableLayer* layer = delegateData.m_Network->AddSpaceToBatchNdLayer(descriptor);
    ARMNN_ASSERT(layer != nullptr);

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Connect
    return Connect(layer, tfLiteNode, delegateData);
}

} // namespace armnnDelegate