blob: a535585699b9fd9d8b0e681b11876d559aa9137f [file] [log] [blame]
Sadik Armagan34fa1bd2020-11-27 12:40:52 +00001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "DelegateUtils.hpp"
9
10#include <algorithm>
11#include <iterator>
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000012#include <vector>
13
14namespace armnnDelegate
15{
16
// Largest input rank accepted by the Split and SplitV visitors in this file.
constexpr unsigned int MaxNumOfTensorDimensions{5U};
18
19TfLiteStatus VisitSplitOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t tfLiteSplitOperatorCode)
24{
25 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26
27 auto* splitParameters = reinterpret_cast<TfLiteSplitParams*>(tfLiteNode->builtin_data);
28 const unsigned int numSplits = NonNegative(splitParameters->num_splits, nodeIndex);
29
30 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
31
32 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
33 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
34 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
40 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitOperatorCode, nodeIndex))
41 {
42 return kTfLiteError;
43 }
44
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000045 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
46
Finn Williams019840d2020-11-30 17:43:28 +000047 ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000048 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
49 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
Matthew Sloyand30bfb52021-04-18 16:40:00 +010050 int32_t axis = axisTensorData[0];
51
52 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
53 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
54 {
55 // Square bracket denotes inclusive n while parenthesis denotes exclusive n
56 // E.g. Rank 4 tensor can have axis in range [-4, 3)
57 // -1 == 3, -2 == 2, -3 == 1, -4 == 0
58 TF_LITE_MAYBE_KERNEL_LOG(
59 tfLiteContext,
60 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
61 axis, nodeIndex);
62 }
63 const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000064
65 std::vector<armnn::TensorInfo> outputs;
66 for (unsigned int i = 0; i < numSplits; ++i)
67 {
68 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
69 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitOperatorCode, nodeIndex))
70 {
71 return kTfLiteError;
72 }
Sadik Armagan90a119b2022-08-05 16:12:49 +010073 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000074 }
75 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
76
77 auto inputDimSize = inputTensorInfo.GetNumDimensions();
78 if (inputDimSize > MaxNumOfTensorDimensions)
79 {
80 TF_LITE_MAYBE_KERNEL_LOG(
81 tfLiteContext,
82 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
83 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
84 return kTfLiteError;
85 }
86
87 std::vector<unsigned int> splitterDimSizes(inputDimSize);
88
89 // Add current input shape to splitterDimSizes
90 for (unsigned int i = 0; i < inputDimSize; ++i)
91 {
92 splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
93 }
94
95 if (splitterDimSizes[splitDim] % numSplits != 0)
96 {
97 TF_LITE_MAYBE_KERNEL_LOG(
98 tfLiteContext,
99 "TfLiteArmnnDelegate: Number of splits #%d must evenly divide the dimension #%d in node #%d: ",
100 numSplits, splitterDimSizes[splitDim], nodeIndex);
101 return kTfLiteError;
102 }
103 splitterDimSizes[splitDim] /= numSplits;
104
105 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
106 for (unsigned int j = 0; j < numSplits; ++j)
107 {
108 // Set the size of the views.
109 for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
110 {
111 splitDescriptor.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
112 }
113 splitDescriptor.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
114 }
115
116 if (!delegateData.m_Network)
117 {
118 // Check if supported
119 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000120 FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000121 tfLiteContext,
122 IsSplitterSupported,
123 delegateData.m_Backends,
124 isSupported,
125 inputTensorInfo,
126 outputTensorInfos,
127 splitDescriptor);
128 return isSupported ? kTfLiteOk : kTfLiteError;
129 }
130
131 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
132 ARMNN_ASSERT(layer != nullptr);
133
134 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
135 {
136 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
137 }
138
139 // Connect the input slots
140 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(0));
141
142 // Prepare output slots
143 for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
144 {
145 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
146 delegateData.m_OutputSlotForNode[
147 static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
148 }
149
150 return kTfLiteOk;
151}
152
153TfLiteStatus VisitSplitVOperator(DelegateData& delegateData,
154 TfLiteContext* tfLiteContext,
155 TfLiteNode* tfLiteNode,
156 int nodeIndex,
157 int32_t tfLiteSplitVOperatorCode)
158{
159 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
160
161 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
162 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
163 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitVOperatorCode, nodeIndex))
164 {
165 return kTfLiteError;
166 }
167
168 const TfLiteTensor& tfLiteSplitsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
169 if (!IsValid(tfLiteContext, tfLiteSplitsTensor, tfLiteSplitVOperatorCode, nodeIndex))
170 {
171 return kTfLiteError;
172 }
173
174 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
175 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitVOperatorCode, nodeIndex))
176 {
177 return kTfLiteError;
178 }
179
180 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
181 const armnn::TensorInfo& splitsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteSplitsTensor);
182 ARMNN_ASSERT(splitsTensorInfo.GetNumDimensions() == 1);
Finn Williams019840d2020-11-30 17:43:28 +0000183 ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000184
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000185 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
186 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100187 int32_t axis = axisTensorData[0];
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000188
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100189 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
190 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000191 {
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100192 TF_LITE_MAYBE_KERNEL_LOG(
193 tfLiteContext,
194 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
195 axis, nodeIndex);
196 }
197 const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0], inputTensorInfo.GetNumDimensions());
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000198
199 auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(tfLiteNode->builtin_data);
200 unsigned int numSplits = 0;
201 if (splitVParameters)
202 {
203 numSplits = NonNegative(splitVParameters->num_splits, nodeIndex);
204 }
205 else
206 {
207 numSplits = splitsTensorInfo.GetNumElements();
208 }
209
210 if (numSplits <= 0)
211 {
212 TF_LITE_MAYBE_KERNEL_LOG(
213 tfLiteContext, "TfLiteArmnnDelegate: Invalid number of splits %d in node #%d",
214 numSplits, nodeIndex);
215 return kTfLiteError;
216 }
217
218 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
219 std::vector<armnn::TensorInfo> outputs;
220 for (unsigned int i = 0; i < numSplits; ++i)
221 {
222 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
223 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitVOperatorCode, nodeIndex))
224 {
225 return kTfLiteError;
226 }
Sadik Armagan90a119b2022-08-05 16:12:49 +0100227 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000228 }
229 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
230
231 auto inputDimSize = inputTensorInfo.GetNumDimensions();
232 if (inputDimSize > MaxNumOfTensorDimensions)
233 {
234 TF_LITE_MAYBE_KERNEL_LOG(
235 tfLiteContext,
236 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
237 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
238 return kTfLiteError;
239 }
240
241 std::vector<int32_t> splitsTensorData(numSplits);
David Monahanc11ba462020-12-03 11:09:46 +0000242 std::memcpy(splitsTensorData.data(), tfLiteSplitsTensor.data.data, splitsTensorInfo.GetNumBytes());
243
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000244
245 unsigned int index = 0;
246 unsigned int inferredIndex = 0;
247 int numberOfInferred = 0;
248 int splitSum = 0;
249
250 for (auto splitData : splitsTensorData)
251 {
252 if (splitData < 0)
253 {
254 ++numberOfInferred;
255 inferredIndex = index;
256 }
257 else
258 {
259 splitSum += splitData;
260 }
261 ++index;
262 }
263
264 // Check for inferred axis
265 if (numberOfInferred == 0)
266 {
267 if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
268 {
269 TF_LITE_MAYBE_KERNEL_LOG(
270 tfLiteContext, "TfLiteArmnnDelegate: SplitV split_sizes does not sum to the dimension of value along"
271 " split_dim in node #%d", nodeIndex);
272 return kTfLiteError;
273 }
274 }
275 else if (numberOfInferred == 1)
276 {
277 splitsTensorData[inferredIndex] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
278 }
279 else
280 {
281 TF_LITE_MAYBE_KERNEL_LOG(
282 tfLiteContext, "TfLiteArmnnDelegate: SplitV cannot infer split size for more than one split in node #%d",
283 nodeIndex);
284 return kTfLiteError;
285 }
286
287 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
288 unsigned int accumSplit = 0;
289 for (unsigned int j = 0; j < numSplits; ++j)
290 {
291 unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsTensorData[j]);
292
293 // Set the size of the views.
294 for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
295 {
296 unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
297 if (dimIdx == splitDim)
298 {
299 dimSize = splitSize;
300 }
301 splitDescriptor.SetViewSize(j, dimIdx, dimSize);
302 }
303
304 splitDescriptor.SetViewOriginCoord(j, splitDim, accumSplit);
305 accumSplit += splitSize;
306 }
307
308 if (!delegateData.m_Network)
309 {
310 // Check if supported
311 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000312 FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000313 tfLiteContext,
314 IsSplitterSupported,
315 delegateData.m_Backends,
316 isSupported,
317 inputTensorInfo,
318 outputTensorInfos,
319 splitDescriptor);
320 return isSupported ? kTfLiteOk : kTfLiteError;
321 }
322
323 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
324 ARMNN_ASSERT(layer != nullptr);
325
326 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
327 {
328 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
329 }
330
331 // Connect
332 return Connect(layer, tfLiteNode, delegateData);
333}
334
335} // namespace armnnDelegate