blob: aaa610259f7139578cf6c6e5070bd29348347552 [file] [log] [blame]
//
// Copyright © 2020,2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <ClassicDelegateUtils.hpp>

#include <algorithm>
#include <cstring>
#include <iterator>
#include <vector>
14namespace armnnDelegate
15{
16
17constexpr unsigned int MaxNumOfTensorDimensions = 5U;
18
19TfLiteStatus VisitSplitOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t tfLiteSplitOperatorCode)
24{
25 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26
27 auto* splitParameters = reinterpret_cast<TfLiteSplitParams*>(tfLiteNode->builtin_data);
28 const unsigned int numSplits = NonNegative(splitParameters->num_splits, nodeIndex);
29
30 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
31
32 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
33 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
34 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
40 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitOperatorCode, nodeIndex))
41 {
42 return kTfLiteError;
43 }
44
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000045 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
46
Ryan OSheac229b3f2023-06-27 22:34:54 +010047 if (GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() != 1)
48 {
49 return kTfLiteError;
50 }
51
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000052 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
53 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
Matthew Sloyand30bfb52021-04-18 16:40:00 +010054 int32_t axis = axisTensorData[0];
55
56 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
57 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
58 {
59 // Square bracket denotes inclusive n while parenthesis denotes exclusive n
60 // E.g. Rank 4 tensor can have axis in range [-4, 3)
61 // -1 == 3, -2 == 2, -3 == 1, -4 == 0
62 TF_LITE_MAYBE_KERNEL_LOG(
63 tfLiteContext,
64 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
65 axis, nodeIndex);
66 }
67 const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000068
69 std::vector<armnn::TensorInfo> outputs;
70 for (unsigned int i = 0; i < numSplits; ++i)
71 {
72 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
73 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitOperatorCode, nodeIndex))
74 {
75 return kTfLiteError;
76 }
Sadik Armagan90a119b2022-08-05 16:12:49 +010077 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000078 }
79 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
80
81 auto inputDimSize = inputTensorInfo.GetNumDimensions();
82 if (inputDimSize > MaxNumOfTensorDimensions)
83 {
84 TF_LITE_MAYBE_KERNEL_LOG(
85 tfLiteContext,
86 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
87 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
88 return kTfLiteError;
89 }
90
91 std::vector<unsigned int> splitterDimSizes(inputDimSize);
92
93 // Add current input shape to splitterDimSizes
94 for (unsigned int i = 0; i < inputDimSize; ++i)
95 {
96 splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
97 }
98
99 if (splitterDimSizes[splitDim] % numSplits != 0)
100 {
101 TF_LITE_MAYBE_KERNEL_LOG(
102 tfLiteContext,
103 "TfLiteArmnnDelegate: Number of splits #%d must evenly divide the dimension #%d in node #%d: ",
104 numSplits, splitterDimSizes[splitDim], nodeIndex);
105 return kTfLiteError;
106 }
107 splitterDimSizes[splitDim] /= numSplits;
108
109 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
110 for (unsigned int j = 0; j < numSplits; ++j)
111 {
112 // Set the size of the views.
113 for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
114 {
115 splitDescriptor.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
116 }
117 splitDescriptor.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
118 }
119
Cathal Corbett53837672022-09-01 11:34:37 +0100120 armnn::BackendId setBackend;
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000121 if (!delegateData.m_Network)
122 {
123 // Check if supported
124 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000125 FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000126 tfLiteContext,
127 IsSplitterSupported,
128 delegateData.m_Backends,
129 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100130 setBackend,
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000131 inputTensorInfo,
132 outputTensorInfos,
133 splitDescriptor);
134 return isSupported ? kTfLiteOk : kTfLiteError;
135 }
136
Mike Kelly07169c82023-08-02 13:23:09 +0100137 auto layerName = GetLayerName(armnn::LayerType::Splitter, nodeIndex);
138 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor, layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +0100139 layer->SetBackendId(setBackend);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000140 ARMNN_ASSERT(layer != nullptr);
141
142 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
143 {
144 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
145 }
146
147 // Connect the input slots
148 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(0));
149
150 // Prepare output slots
151 for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
152 {
153 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
154 delegateData.m_OutputSlotForNode[
155 static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
156 }
157
158 return kTfLiteOk;
159}
160
161TfLiteStatus VisitSplitVOperator(DelegateData& delegateData,
162 TfLiteContext* tfLiteContext,
163 TfLiteNode* tfLiteNode,
164 int nodeIndex,
165 int32_t tfLiteSplitVOperatorCode)
166{
167 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
168
169 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
170 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
171 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitVOperatorCode, nodeIndex))
172 {
173 return kTfLiteError;
174 }
175
176 const TfLiteTensor& tfLiteSplitsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
177 if (!IsValid(tfLiteContext, tfLiteSplitsTensor, tfLiteSplitVOperatorCode, nodeIndex))
178 {
179 return kTfLiteError;
180 }
181
182 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
183 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitVOperatorCode, nodeIndex))
184 {
185 return kTfLiteError;
186 }
187
188 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
189 const armnn::TensorInfo& splitsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteSplitsTensor);
Ryan OSheac229b3f2023-06-27 22:34:54 +0100190
191 if (splitsTensorInfo.GetNumDimensions() != 1)
192 {
193 return kTfLiteError;
194 }
195
196 if (GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() != 1)
197 {
198 return kTfLiteError;
199 }
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000200
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000201 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
202 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100203 int32_t axis = axisTensorData[0];
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000204
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100205 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
206 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000207 {
Matthew Sloyand30bfb52021-04-18 16:40:00 +0100208 TF_LITE_MAYBE_KERNEL_LOG(
209 tfLiteContext,
210 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
211 axis, nodeIndex);
212 }
213 const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0], inputTensorInfo.GetNumDimensions());
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000214
215 auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(tfLiteNode->builtin_data);
216 unsigned int numSplits = 0;
217 if (splitVParameters)
218 {
219 numSplits = NonNegative(splitVParameters->num_splits, nodeIndex);
220 }
221 else
222 {
223 numSplits = splitsTensorInfo.GetNumElements();
224 }
225
226 if (numSplits <= 0)
227 {
228 TF_LITE_MAYBE_KERNEL_LOG(
229 tfLiteContext, "TfLiteArmnnDelegate: Invalid number of splits %d in node #%d",
230 numSplits, nodeIndex);
231 return kTfLiteError;
232 }
233
234 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
235 std::vector<armnn::TensorInfo> outputs;
236 for (unsigned int i = 0; i < numSplits; ++i)
237 {
238 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
239 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitVOperatorCode, nodeIndex))
240 {
241 return kTfLiteError;
242 }
Sadik Armagan90a119b2022-08-05 16:12:49 +0100243 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000244 }
245 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
246
247 auto inputDimSize = inputTensorInfo.GetNumDimensions();
248 if (inputDimSize > MaxNumOfTensorDimensions)
249 {
250 TF_LITE_MAYBE_KERNEL_LOG(
251 tfLiteContext,
252 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
253 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
254 return kTfLiteError;
255 }
256
257 std::vector<int32_t> splitsTensorData(numSplits);
David Monahanc11ba462020-12-03 11:09:46 +0000258 std::memcpy(splitsTensorData.data(), tfLiteSplitsTensor.data.data, splitsTensorInfo.GetNumBytes());
259
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000260
261 unsigned int index = 0;
262 unsigned int inferredIndex = 0;
263 int numberOfInferred = 0;
264 int splitSum = 0;
265
266 for (auto splitData : splitsTensorData)
267 {
268 if (splitData < 0)
269 {
270 ++numberOfInferred;
271 inferredIndex = index;
272 }
273 else
274 {
275 splitSum += splitData;
276 }
277 ++index;
278 }
279
280 // Check for inferred axis
281 if (numberOfInferred == 0)
282 {
283 if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
284 {
285 TF_LITE_MAYBE_KERNEL_LOG(
286 tfLiteContext, "TfLiteArmnnDelegate: SplitV split_sizes does not sum to the dimension of value along"
287 " split_dim in node #%d", nodeIndex);
288 return kTfLiteError;
289 }
290 }
291 else if (numberOfInferred == 1)
292 {
293 splitsTensorData[inferredIndex] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
294 }
295 else
296 {
297 TF_LITE_MAYBE_KERNEL_LOG(
298 tfLiteContext, "TfLiteArmnnDelegate: SplitV cannot infer split size for more than one split in node #%d",
299 nodeIndex);
300 return kTfLiteError;
301 }
302
303 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
304 unsigned int accumSplit = 0;
305 for (unsigned int j = 0; j < numSplits; ++j)
306 {
307 unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsTensorData[j]);
308
309 // Set the size of the views.
310 for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
311 {
312 unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
313 if (dimIdx == splitDim)
314 {
315 dimSize = splitSize;
316 }
317 splitDescriptor.SetViewSize(j, dimIdx, dimSize);
318 }
319
320 splitDescriptor.SetViewOriginCoord(j, splitDim, accumSplit);
321 accumSplit += splitSize;
322 }
323
Cathal Corbett53837672022-09-01 11:34:37 +0100324 armnn::BackendId setBackend;
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000325 if (!delegateData.m_Network)
326 {
327 // Check if supported
328 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +0000329 FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000330 tfLiteContext,
331 IsSplitterSupported,
332 delegateData.m_Backends,
333 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100334 setBackend,
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000335 inputTensorInfo,
336 outputTensorInfos,
337 splitDescriptor);
338 return isSupported ? kTfLiteOk : kTfLiteError;
339 }
340
341 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100342 layer->SetBackendId(setBackend);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000343 ARMNN_ASSERT(layer != nullptr);
344
345 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
346 {
347 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
348 }
349
Ryan OShea4c231de2023-01-17 15:19:20 +0000350 // try to connect the Constant Inputs if there are any
Mike Kelly07169c82023-08-02 13:23:09 +0100351 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Ryan OShea4c231de2023-01-17 15:19:20 +0000352 {
353 return kTfLiteError;
354 }
355
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000356 // Connect
357 return Connect(layer, tfLiteNode, delegateData);
358}
359
360} // namespace armnnDelegate