blob: 57b7f8074ee5896d86000f45697d4d1604848abb [file] [log] [blame]
Sadik Armagan34fa1bd2020-11-27 12:40:52 +00001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2020,2022-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan34fa1bd2020-11-27 12:40:52 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
Sadik Armagan34fa1bd2020-11-27 12:40:52 +00009
10#include <algorithm>
11#include <iterator>
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000012#include <vector>
13
14namespace armnnDelegate
15{
16
// Highest input rank accepted by the Split / SplitV visitors; larger inputs are rejected with kTfLiteError.
constexpr unsigned int MaxNumOfTensorDimensions{5U};
18
19TfLiteStatus VisitSplitOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t tfLiteSplitOperatorCode)
24{
25 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26
27 auto* splitParameters = reinterpret_cast<TfLiteSplitParams*>(tfLiteNode->builtin_data);
28 const unsigned int numSplits = NonNegative(splitParameters->num_splits, nodeIndex);
29
30 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
31
32 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
33 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
34 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
40 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitOperatorCode, nodeIndex))
41 {
42 return kTfLiteError;
43 }
44
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000045 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
46
Ryan OSheac229b3f2023-06-27 22:34:54 +010047 if (GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() != 1)
48 {
49 return kTfLiteError;
50 }
51
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000052 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
53 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
Matthew Sloyand30bfb52021-04-18 16:40:00 +010054 int32_t axis = axisTensorData[0];
55
56 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
57 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
58 {
59 // Square bracket denotes inclusive n while parenthesis denotes exclusive n
60 // E.g. Rank 4 tensor can have axis in range [-4, 3)
61 // -1 == 3, -2 == 2, -3 == 1, -4 == 0
62 TF_LITE_MAYBE_KERNEL_LOG(
63 tfLiteContext,
64 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
65 axis, nodeIndex);
66 }
67 const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000068
69 std::vector<armnn::TensorInfo> outputs;
70 for (unsigned int i = 0; i < numSplits; ++i)
71 {
72 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
73 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitOperatorCode, nodeIndex))
74 {
75 return kTfLiteError;
76 }
Sadik Armagan90a119b2022-08-05 16:12:49 +010077 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000078 }
79 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
80
81 auto inputDimSize = inputTensorInfo.GetNumDimensions();
82 if (inputDimSize > MaxNumOfTensorDimensions)
83 {
84 TF_LITE_MAYBE_KERNEL_LOG(
85 tfLiteContext,
86 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
87 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
88 return kTfLiteError;
89 }
90
91 std::vector<unsigned int> splitterDimSizes(inputDimSize);
92
93 // Add current input shape to splitterDimSizes
94 for (unsigned int i = 0; i < inputDimSize; ++i)
95 {
96 splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
97 }
98
99 if (splitterDimSizes[splitDim] % numSplits != 0)
100 {
101 TF_LITE_MAYBE_KERNEL_LOG(
102 tfLiteContext,
103 "TfLiteArmnnDelegate: Number of splits #%d must evenly divide the dimension #%d in node #%d: ",
104 numSplits, splitterDimSizes[splitDim], nodeIndex);
105 return kTfLiteError;
106 }
107 splitterDimSizes[splitDim] /= numSplits;
108
109 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
Mike Kelly363b5722023-10-11 14:25:50 +0100110 splitDescriptor.SetAxis(axis);
111
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000112 for (unsigned int j = 0; j < numSplits; ++j)
113 {
114 // Set the size of the views.
115 for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
116 {
117 splitDescriptor.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
118 }
119 splitDescriptor.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
120 }
121
Cathal Corbett53837672022-09-01 11:34:37 +0100122 armnn::BackendId setBackend;
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000123 if (!delegateData.m_Network)
124 {
125 // Check if supported
126 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000127 FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000128 tfLiteContext,
129 IsSplitterSupported,
130 delegateData.m_Backends,
131 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100132 setBackend,
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000133 inputTensorInfo,
134 outputTensorInfos,
135 splitDescriptor);
136 return isSupported ? kTfLiteOk : kTfLiteError;
137 }
138
Mike Kelly07169c82023-08-02 13:23:09 +0100139 auto layerName = GetLayerName(armnn::LayerType::Splitter, nodeIndex);
140 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor, layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +0100141 layer->SetBackendId(setBackend);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000142 ARMNN_ASSERT(layer != nullptr);
143
144 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
145 {
146 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
147 }
148
149 // Connect the input slots
150 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(0));
151
152 // Prepare output slots
153 for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
154 {
155 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
156 delegateData.m_OutputSlotForNode[
157 static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
158 }
159
160 return kTfLiteOk;
161}
162
163TfLiteStatus VisitSplitVOperator(DelegateData& delegateData,
164 TfLiteContext* tfLiteContext,
165 TfLiteNode* tfLiteNode,
166 int nodeIndex,
167 int32_t tfLiteSplitVOperatorCode)
168{
169 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
170
171 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
172 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
173 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitVOperatorCode, nodeIndex))
174 {
175 return kTfLiteError;
176 }
177
178 const TfLiteTensor& tfLiteSplitsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
179 if (!IsValid(tfLiteContext, tfLiteSplitsTensor, tfLiteSplitVOperatorCode, nodeIndex))
180 {
181 return kTfLiteError;
182 }
183
184 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
185 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitVOperatorCode, nodeIndex))
186 {
187 return kTfLiteError;
188 }
189
190 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
191 const armnn::TensorInfo& splitsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteSplitsTensor);
Ryan OSheac229b3f2023-06-27 22:34:54 +0100192
193 if (splitsTensorInfo.GetNumDimensions() != 1)
194 {
195 return kTfLiteError;
196 }
197
198 if (GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() != 1)
199 {
200 return kTfLiteError;
201 }
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000202
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000203 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
204 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100205 int32_t axis = axisTensorData[0];
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000206
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100207 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
208 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000209 {
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100210 TF_LITE_MAYBE_KERNEL_LOG(
211 tfLiteContext,
212 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
213 axis, nodeIndex);
214 }
215 const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0], inputTensorInfo.GetNumDimensions());
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000216
217 auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(tfLiteNode->builtin_data);
218 unsigned int numSplits = 0;
219 if (splitVParameters)
220 {
221 numSplits = NonNegative(splitVParameters->num_splits, nodeIndex);
222 }
223 else
224 {
225 numSplits = splitsTensorInfo.GetNumElements();
226 }
227
228 if (numSplits <= 0)
229 {
230 TF_LITE_MAYBE_KERNEL_LOG(
231 tfLiteContext, "TfLiteArmnnDelegate: Invalid number of splits %d in node #%d",
232 numSplits, nodeIndex);
233 return kTfLiteError;
234 }
235
236 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
237 std::vector<armnn::TensorInfo> outputs;
238 for (unsigned int i = 0; i < numSplits; ++i)
239 {
240 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
241 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitVOperatorCode, nodeIndex))
242 {
243 return kTfLiteError;
244 }
Sadik Armagan90a119b2022-08-05 16:12:49 +0100245 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000246 }
247 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
248
249 auto inputDimSize = inputTensorInfo.GetNumDimensions();
250 if (inputDimSize > MaxNumOfTensorDimensions)
251 {
252 TF_LITE_MAYBE_KERNEL_LOG(
253 tfLiteContext,
254 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
255 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
256 return kTfLiteError;
257 }
258
259 std::vector<int32_t> splitsTensorData(numSplits);
David Monahanc11ba462020-12-03 11:09:46 +0000260 std::memcpy(splitsTensorData.data(), tfLiteSplitsTensor.data.data, splitsTensorInfo.GetNumBytes());
261
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000262
263 unsigned int index = 0;
264 unsigned int inferredIndex = 0;
265 int numberOfInferred = 0;
266 int splitSum = 0;
267
268 for (auto splitData : splitsTensorData)
269 {
270 if (splitData < 0)
271 {
272 ++numberOfInferred;
273 inferredIndex = index;
274 }
275 else
276 {
277 splitSum += splitData;
278 }
279 ++index;
280 }
281
282 // Check for inferred axis
283 if (numberOfInferred == 0)
284 {
285 if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
286 {
287 TF_LITE_MAYBE_KERNEL_LOG(
288 tfLiteContext, "TfLiteArmnnDelegate: SplitV split_sizes does not sum to the dimension of value along"
289 " split_dim in node #%d", nodeIndex);
290 return kTfLiteError;
291 }
292 }
293 else if (numberOfInferred == 1)
294 {
295 splitsTensorData[inferredIndex] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
296 }
297 else
298 {
299 TF_LITE_MAYBE_KERNEL_LOG(
300 tfLiteContext, "TfLiteArmnnDelegate: SplitV cannot infer split size for more than one split in node #%d",
301 nodeIndex);
302 return kTfLiteError;
303 }
304
305 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
Mike Kelly363b5722023-10-11 14:25:50 +0100306 splitDescriptor.SetAxis(axis);
307
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000308 unsigned int accumSplit = 0;
309 for (unsigned int j = 0; j < numSplits; ++j)
310 {
311 unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsTensorData[j]);
312
313 // Set the size of the views.
314 for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
315 {
316 unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
317 if (dimIdx == splitDim)
318 {
319 dimSize = splitSize;
320 }
321 splitDescriptor.SetViewSize(j, dimIdx, dimSize);
322 }
323
324 splitDescriptor.SetViewOriginCoord(j, splitDim, accumSplit);
325 accumSplit += splitSize;
326 }
327
Cathal Corbett53837672022-09-01 11:34:37 +0100328 armnn::BackendId setBackend;
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000329 if (!delegateData.m_Network)
330 {
331 // Check if supported
332 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000333 FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000334 tfLiteContext,
335 IsSplitterSupported,
336 delegateData.m_Backends,
337 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100338 setBackend,
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000339 inputTensorInfo,
340 outputTensorInfos,
341 splitDescriptor);
342 return isSupported ? kTfLiteOk : kTfLiteError;
343 }
344
345 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100346 layer->SetBackendId(setBackend);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000347 ARMNN_ASSERT(layer != nullptr);
348
349 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
350 {
351 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
352 }
353
Ryan OShea4c231de2023-01-17 15:19:20 +0000354 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100355 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +0000356 {
357 return kTfLiteError;
358 }
359
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000360 // Connect
361 return Connect(layer, tfLiteNode, delegateData);
362}
363
364} // namespace armnnDelegate