blob: b183b55c54cb14a6dd77a12e5b5541f55df40386 [file] [log] [blame]
Sadik Armagan34fa1bd2020-11-27 12:40:52 +00001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2020,2022-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan34fa1bd2020-11-27 12:40:52 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include "DelegateUtils.hpp"

#include <algorithm>
#include <cstring>
#include <iterator>
#include <vector>
13
14namespace armnnDelegate
15{
16
17constexpr unsigned int MaxNumOfTensorDimensions = 5U;
18
19TfLiteStatus VisitSplitOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t tfLiteSplitOperatorCode)
24{
25 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26
27 auto* splitParameters = reinterpret_cast<TfLiteSplitParams*>(tfLiteNode->builtin_data);
28 const unsigned int numSplits = NonNegative(splitParameters->num_splits, nodeIndex);
29
30 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
31
32 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
33 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
34 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
40 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitOperatorCode, nodeIndex))
41 {
42 return kTfLiteError;
43 }
44
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000045 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
46
Finn Williams019840d2020-11-30 17:43:28 +000047 ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000048 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
49 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
Matthew Sloyand30bfb52021-04-18 16:40:00 +010050 int32_t axis = axisTensorData[0];
51
52 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
53 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
54 {
55 // Square bracket denotes inclusive n while parenthesis denotes exclusive n
56 // E.g. Rank 4 tensor can have axis in range [-4, 3)
57 // -1 == 3, -2 == 2, -3 == 1, -4 == 0
58 TF_LITE_MAYBE_KERNEL_LOG(
59 tfLiteContext,
60 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
61 axis, nodeIndex);
62 }
63 const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000064
65 std::vector<armnn::TensorInfo> outputs;
66 for (unsigned int i = 0; i < numSplits; ++i)
67 {
68 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
69 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitOperatorCode, nodeIndex))
70 {
71 return kTfLiteError;
72 }
Sadik Armagan90a119b2022-08-05 16:12:49 +010073 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000074 }
75 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
76
77 auto inputDimSize = inputTensorInfo.GetNumDimensions();
78 if (inputDimSize > MaxNumOfTensorDimensions)
79 {
80 TF_LITE_MAYBE_KERNEL_LOG(
81 tfLiteContext,
82 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
83 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
84 return kTfLiteError;
85 }
86
87 std::vector<unsigned int> splitterDimSizes(inputDimSize);
88
89 // Add current input shape to splitterDimSizes
90 for (unsigned int i = 0; i < inputDimSize; ++i)
91 {
92 splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
93 }
94
95 if (splitterDimSizes[splitDim] % numSplits != 0)
96 {
97 TF_LITE_MAYBE_KERNEL_LOG(
98 tfLiteContext,
99 "TfLiteArmnnDelegate: Number of splits #%d must evenly divide the dimension #%d in node #%d: ",
100 numSplits, splitterDimSizes[splitDim], nodeIndex);
101 return kTfLiteError;
102 }
103 splitterDimSizes[splitDim] /= numSplits;
104
105 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
106 for (unsigned int j = 0; j < numSplits; ++j)
107 {
108 // Set the size of the views.
109 for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
110 {
111 splitDescriptor.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
112 }
113 splitDescriptor.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
114 }
115
Cathal Corbett53837672022-09-01 11:34:37 +0100116 armnn::BackendId setBackend;
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000117 if (!delegateData.m_Network)
118 {
119 // Check if supported
120 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000121 FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000122 tfLiteContext,
123 IsSplitterSupported,
124 delegateData.m_Backends,
125 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100126 setBackend,
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000127 inputTensorInfo,
128 outputTensorInfos,
129 splitDescriptor);
130 return isSupported ? kTfLiteOk : kTfLiteError;
131 }
132
133 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100134 layer->SetBackendId(setBackend);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000135 ARMNN_ASSERT(layer != nullptr);
136
137 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
138 {
139 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
140 }
141
142 // Connect the input slots
143 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(0));
144
145 // Prepare output slots
146 for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
147 {
148 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
149 delegateData.m_OutputSlotForNode[
150 static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
151 }
152
153 return kTfLiteOk;
154}
155
156TfLiteStatus VisitSplitVOperator(DelegateData& delegateData,
157 TfLiteContext* tfLiteContext,
158 TfLiteNode* tfLiteNode,
159 int nodeIndex,
160 int32_t tfLiteSplitVOperatorCode)
161{
162 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
163
164 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
165 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
166 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitVOperatorCode, nodeIndex))
167 {
168 return kTfLiteError;
169 }
170
171 const TfLiteTensor& tfLiteSplitsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
172 if (!IsValid(tfLiteContext, tfLiteSplitsTensor, tfLiteSplitVOperatorCode, nodeIndex))
173 {
174 return kTfLiteError;
175 }
176
177 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
178 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitVOperatorCode, nodeIndex))
179 {
180 return kTfLiteError;
181 }
182
183 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
184 const armnn::TensorInfo& splitsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteSplitsTensor);
185 ARMNN_ASSERT(splitsTensorInfo.GetNumDimensions() == 1);
Finn Williams019840d2020-11-30 17:43:28 +0000186 ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000187
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000188 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
189 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100190 int32_t axis = axisTensorData[0];
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000191
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100192 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
193 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000194 {
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100195 TF_LITE_MAYBE_KERNEL_LOG(
196 tfLiteContext,
197 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
198 axis, nodeIndex);
199 }
200 const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0], inputTensorInfo.GetNumDimensions());
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000201
202 auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(tfLiteNode->builtin_data);
203 unsigned int numSplits = 0;
204 if (splitVParameters)
205 {
206 numSplits = NonNegative(splitVParameters->num_splits, nodeIndex);
207 }
208 else
209 {
210 numSplits = splitsTensorInfo.GetNumElements();
211 }
212
213 if (numSplits <= 0)
214 {
215 TF_LITE_MAYBE_KERNEL_LOG(
216 tfLiteContext, "TfLiteArmnnDelegate: Invalid number of splits %d in node #%d",
217 numSplits, nodeIndex);
218 return kTfLiteError;
219 }
220
221 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
222 std::vector<armnn::TensorInfo> outputs;
223 for (unsigned int i = 0; i < numSplits; ++i)
224 {
225 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
226 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitVOperatorCode, nodeIndex))
227 {
228 return kTfLiteError;
229 }
Sadik Armagan90a119b2022-08-05 16:12:49 +0100230 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000231 }
232 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
233
234 auto inputDimSize = inputTensorInfo.GetNumDimensions();
235 if (inputDimSize > MaxNumOfTensorDimensions)
236 {
237 TF_LITE_MAYBE_KERNEL_LOG(
238 tfLiteContext,
239 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
240 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
241 return kTfLiteError;
242 }
243
244 std::vector<int32_t> splitsTensorData(numSplits);
David Monahanc11ba462020-12-03 11:09:46 +0000245 std::memcpy(splitsTensorData.data(), tfLiteSplitsTensor.data.data, splitsTensorInfo.GetNumBytes());
246
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000247
248 unsigned int index = 0;
249 unsigned int inferredIndex = 0;
250 int numberOfInferred = 0;
251 int splitSum = 0;
252
253 for (auto splitData : splitsTensorData)
254 {
255 if (splitData < 0)
256 {
257 ++numberOfInferred;
258 inferredIndex = index;
259 }
260 else
261 {
262 splitSum += splitData;
263 }
264 ++index;
265 }
266
267 // Check for inferred axis
268 if (numberOfInferred == 0)
269 {
270 if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
271 {
272 TF_LITE_MAYBE_KERNEL_LOG(
273 tfLiteContext, "TfLiteArmnnDelegate: SplitV split_sizes does not sum to the dimension of value along"
274 " split_dim in node #%d", nodeIndex);
275 return kTfLiteError;
276 }
277 }
278 else if (numberOfInferred == 1)
279 {
280 splitsTensorData[inferredIndex] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
281 }
282 else
283 {
284 TF_LITE_MAYBE_KERNEL_LOG(
285 tfLiteContext, "TfLiteArmnnDelegate: SplitV cannot infer split size for more than one split in node #%d",
286 nodeIndex);
287 return kTfLiteError;
288 }
289
290 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
291 unsigned int accumSplit = 0;
292 for (unsigned int j = 0; j < numSplits; ++j)
293 {
294 unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsTensorData[j]);
295
296 // Set the size of the views.
297 for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
298 {
299 unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
300 if (dimIdx == splitDim)
301 {
302 dimSize = splitSize;
303 }
304 splitDescriptor.SetViewSize(j, dimIdx, dimSize);
305 }
306
307 splitDescriptor.SetViewOriginCoord(j, splitDim, accumSplit);
308 accumSplit += splitSize;
309 }
310
Cathal Corbett53837672022-09-01 11:34:37 +0100311 armnn::BackendId setBackend;
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000312 if (!delegateData.m_Network)
313 {
314 // Check if supported
315 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000316 FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000317 tfLiteContext,
318 IsSplitterSupported,
319 delegateData.m_Backends,
320 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100321 setBackend,
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000322 inputTensorInfo,
323 outputTensorInfos,
324 splitDescriptor);
325 return isSupported ? kTfLiteOk : kTfLiteError;
326 }
327
328 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100329 layer->SetBackendId(setBackend);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000330 ARMNN_ASSERT(layer != nullptr);
331
332 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
333 {
334 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
335 }
336
Ryan OShea4c231de2023-01-17 15:19:20 +0000337 // try to connect the Constant Inputs if there are any
338 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
339 {
340 return kTfLiteError;
341 }
342
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000343 // Connect
344 return Connect(layer, tfLiteNode, delegateData);
345}
346
347} // namespace armnnDelegate