//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>
#include <DelegateUtils.hpp>

#include <algorithm>
#include <cstring>
#include <iterator>
#include <vector>
14
15namespace armnnOpaqueDelegate
16{
17
// Maximum tensor rank the split visitors below accept; inputs with more
// dimensions are rejected back to the TfLite runtime with kTfLiteError.
constexpr unsigned int MaxNumOfTensorDimensions = 5U;
19
20TfLiteStatus VisitSplitOperator(DelegateData& delegateData,
21 TfLiteOpaqueContext* tfLiteContext,
22 TfLiteOpaqueNode* tfLiteNode,
23 int nodeIndex,
24 int32_t tfLiteSplitOperatorCode)
25{
26 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
27
28 auto* splitParameters = reinterpret_cast<TfLiteSplitParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
29 int numSplits = NonNegative(splitParameters->num_splits, nodeIndex);
30
31 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
32
33 // Gather input indices and use to get Axis tensor.
34 const int* inputTensors;
35 auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
36 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
37 {
38 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
39 tfLiteContext,
40 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
41 nodeIndex);
42 return kTfLiteError;
43 }
44
45 const TfLiteOpaqueTensor* tfLiteAxisTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
46 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitOperatorCode, nodeIndex))
47 {
48 return kTfLiteError;
49 }
50
51 // Use input indices to get input tensor.
52 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
53 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitOperatorCode, nodeIndex))
54 {
55 return kTfLiteError;
56 }
57
58 // Gather output indices and use to get output tensors.
59 const int* outputTensors;
60 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numSplits) != kTfLiteOk)
61 {
62 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
63 tfLiteContext,
64 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
65 nodeIndex);
66 return kTfLiteError;
67 }
68
69 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
70
Ryan OSheac229b3f2023-06-27 22:34:54 +010071 if (GetTensorInfoForTfLiteOpaqueTensor(tfLiteAxisTensor).GetNumElements() != 1)
72 {
73 return kTfLiteError;
74 }
75
David Monahanc833cef2023-05-03 15:53:03 +010076 auto* axisTensorDataPtr = static_cast<uint32_t*>(TfLiteOpaqueTensorData(tfLiteAxisTensor));
77 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
78 int32_t axis = axisTensorData[0];
79
80 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
81 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
82 {
83 // Square bracket denotes inclusive n while parenthesis denotes exclusive n
84 // E.g. Rank 4 tensor can have axis in range [-4, 3)
85 // -1 == 3, -2 == 2, -3 == 1, -4 == 0
86 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
87 tfLiteContext,
Ryan OShea59f8f652023-05-11 20:37:53 +010088 "TfLiteOpaqueArmnnDelegate: Operation has invalid axis: #%d. "
89 "Axis must be in range [-n, n) in node #%d:",
David Monahanc833cef2023-05-03 15:53:03 +010090 axis, nodeIndex);
91 }
92 const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
93
94 std::vector<armnn::TensorInfo> outputs;
95 for (int i = 0; i < numSplits; ++i)
96 {
97 const TfLiteOpaqueTensor* tfLiteOutputTensor =
98 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[i]);
99 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitOperatorCode, nodeIndex))
100 {
101 return kTfLiteError;
102 }
103 outputs.push_back(GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true));
104 }
105 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
106
107 auto inputDimSize = inputTensorInfo.GetNumDimensions();
108 if (inputDimSize > MaxNumOfTensorDimensions)
109 {
110 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
111 tfLiteContext,
Ryan OShea59f8f652023-05-11 20:37:53 +0100112 "TfLiteOpaqueArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be "
113 "greater than #%d in node #%d: ",
114 inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
David Monahanc833cef2023-05-03 15:53:03 +0100115 return kTfLiteError;
116 }
117
118 std::vector<unsigned int> splitterDimSizes(inputDimSize);
119
120 // Add current input shape to splitterDimSizes
121 for (unsigned int i = 0; i < inputDimSize; ++i)
122 {
123 splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
124 }
125
126 if (splitterDimSizes[splitDim] % numSplits != 0)
127 {
128 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
129 tfLiteContext,
Ryan OShea59f8f652023-05-11 20:37:53 +0100130 "TfLiteOpaqueArmnnDelegate: Number of splits #%d must evenly divide the dimension #%d in node #%d: ",
David Monahanc833cef2023-05-03 15:53:03 +0100131 numSplits, splitterDimSizes[splitDim], nodeIndex);
132 return kTfLiteError;
133 }
134 splitterDimSizes[splitDim] /= numSplits;
135
136 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
Mike Kelly363b5722023-10-11 14:25:50 +0100137 splitDescriptor.SetAxis(axis);
138
David Monahanc833cef2023-05-03 15:53:03 +0100139 for (int j = 0; j < numSplits; ++j)
140 {
141 // Set the size of the views.
142 for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
143 {
144 splitDescriptor.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
145 }
146 splitDescriptor.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
147 }
148
149 armnn::BackendId setBackend;
150 if (!delegateData.m_Network)
151 {
152 // Check if supported
153 bool isSupported = false;
154 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SPLIT",
155 tfLiteContext,
156 IsSplitterSupported,
157 delegateData.m_Backends,
158 isSupported,
159 setBackend,
160 inputTensorInfo,
161 outputTensorInfos,
162 splitDescriptor);
163 return isSupported ? kTfLiteOk : kTfLiteError;
164 }
165
Mike Kellya2806502023-08-03 10:42:11 +0100166 auto layerName = GetName(armnn::LayerType::Splitter, nodeIndex);
167 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor, layerName.c_str());
David Monahanc833cef2023-05-03 15:53:03 +0100168 layer->SetBackendId(setBackend);
169 ARMNN_ASSERT(layer != nullptr);
170
171 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
172 {
173 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
174 }
175
176 // Connect the input slots
Ryan OShea59f8f652023-05-11 20:37:53 +0100177 delegateData.m_OutputSlotForNode[inputTensors[1]]->Connect(layer->GetInputSlot(0));
178
179 if(numSplits != static_cast<int>(layer->GetNumOutputSlots()))
David Monahanc833cef2023-05-03 15:53:03 +0100180 {
Ryan OShea59f8f652023-05-11 20:37:53 +0100181 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
182 tfLiteContext,
183 "TfLiteOpaqueArmnnDelegate: Expected number of splits #%d does not "
184 "match the number of output slots #%d in node #%d: ",
185 numSplits, layer->GetNumOutputSlots(), nodeIndex);
David Monahanc833cef2023-05-03 15:53:03 +0100186 return kTfLiteError;
187 }
Ryan OShea59f8f652023-05-11 20:37:53 +0100188
189 // Prepare output slots
190 for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
191 {
192 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
193 delegateData.m_OutputSlotForNode[
194 static_cast<unsigned long>(outputTensors[outputIndex])] = &outputSlot;
195 }
David Monahanc833cef2023-05-03 15:53:03 +0100196 return kTfLiteOk;
197}
198
199TfLiteStatus VisitSplitVOperator(DelegateData& delegateData,
200 TfLiteOpaqueContext* tfLiteContext,
201 TfLiteOpaqueNode* tfLiteNode,
202 int nodeIndex,
203 int32_t tfLiteSplitVOperatorCode)
204{
205
206 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
207
208 const int* inputTensors;
209 auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
210 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
211 {
212 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
213 tfLiteContext,
214 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
215 nodeIndex);
216 return kTfLiteError;
217 }
218
219 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
220 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitVOperatorCode, nodeIndex))
221 {
222 return kTfLiteError;
223 }
224
225 const TfLiteOpaqueTensor* tfLiteSplitsTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
226 if (!IsValid(tfLiteContext, tfLiteSplitsTensor, tfLiteSplitVOperatorCode, nodeIndex))
227 {
228 return kTfLiteError;
229 }
230
231 const TfLiteOpaqueTensor* tfLiteAxisTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[2]);
232 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitVOperatorCode, nodeIndex))
233 {
234 return kTfLiteError;
235 }
236
237 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
238 const armnn::TensorInfo& splitsTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteSplitsTensor);
Ryan OSheac229b3f2023-06-27 22:34:54 +0100239
240 if (splitsTensorInfo.GetNumDimensions() != 1)
241 {
242 return kTfLiteError;
243 }
244
245 if (GetTensorInfoForTfLiteOpaqueTensor(tfLiteAxisTensor).GetNumElements() != 1)
246 {
247 return kTfLiteError;
248 }
David Monahanc833cef2023-05-03 15:53:03 +0100249
250 auto* axisTensorDataPtr = static_cast<uint32_t*>(TfLiteOpaqueTensorData(tfLiteAxisTensor));
251 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
252 int32_t axis = axisTensorData[0];
253
254 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
255 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
256 {
257 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
258 tfLiteContext,
Ryan OShea59f8f652023-05-11 20:37:53 +0100259 "TfLiteOpaqueArmnnDelegate: Operation has invalid axis: #%d. "
260 "Axis must be in range [-n, n) in node #%d:",
David Monahanc833cef2023-05-03 15:53:03 +0100261 axis, nodeIndex);
262 }
263 const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0], inputTensorInfo.GetNumDimensions());
264
265 auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
266 int numSplits = 0;
267 if (splitVParameters)
268 {
269 numSplits = NonNegative(splitVParameters->num_splits, nodeIndex);
270 }
271 else
272 {
273 numSplits = splitsTensorInfo.GetNumElements();
274 }
275
276 if (numSplits <= 0)
277 {
278 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
Ryan OShea59f8f652023-05-11 20:37:53 +0100279 tfLiteContext,
280 "TfLiteOpaqueArmnnDelegate: Invalid number of splits %d in node #%d",
David Monahanc833cef2023-05-03 15:53:03 +0100281 numSplits, nodeIndex);
282 return kTfLiteError;
283 }
284
285 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
286
287 // Gather output indices and use to get output tensors.
288 const int* outputTensors;
289 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numSplits) != kTfLiteOk)
290 {
291 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
292 tfLiteContext,
293 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
294 nodeIndex);
295 return kTfLiteError;
296 }
297 std::vector<armnn::TensorInfo> outputs;
298 for (int i = 0; i < numSplits; ++i)
299 {
300 const TfLiteOpaqueTensor* tfLiteOutputTensor =
301 TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[i]);
302 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitVOperatorCode, nodeIndex))
303 {
304 return kTfLiteError;
305 }
306 outputs.push_back(GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true));
307 }
308 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
309
310 auto inputDimSize = inputTensorInfo.GetNumDimensions();
311 if (inputDimSize > MaxNumOfTensorDimensions)
312 {
313 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
314 tfLiteContext,
Ryan OShea59f8f652023-05-11 20:37:53 +0100315 "TfLiteOpaqueArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be "
316 "greater than #%d in node #%d: ",
317 inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
David Monahanc833cef2023-05-03 15:53:03 +0100318 return kTfLiteError;
319 }
320
321 std::vector<int32_t> splitsTensorData(numSplits);
322 std::memcpy(splitsTensorData.data(), TfLiteOpaqueTensorData(tfLiteSplitsTensor), splitsTensorInfo.GetNumBytes());
323
324
325 unsigned int index = 0;
326 unsigned int inferredIndex = 0;
327 int numberOfInferred = 0;
328 int splitSum = 0;
329
330 for (auto splitData : splitsTensorData)
331 {
332 if (splitData < 0)
333 {
334 ++numberOfInferred;
335 inferredIndex = index;
336 }
337 else
338 {
339 splitSum += splitData;
340 }
341 ++index;
342 }
343
344 // Check for inferred axis
345 if (numberOfInferred == 0)
346 {
347 if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
348 {
349 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
Ryan OShea59f8f652023-05-11 20:37:53 +0100350 tfLiteContext,
351 "TfLiteOpaqueArmnnDelegate: SplitV split_sizes does not sum to the dimension "
352 "of value along split_dim in node #%d",
353 nodeIndex);
David Monahanc833cef2023-05-03 15:53:03 +0100354 return kTfLiteError;
355 }
356 }
357 else if (numberOfInferred == 1)
358 {
359 splitsTensorData[inferredIndex] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
360 }
361 else
362 {
363 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
Ryan OShea59f8f652023-05-11 20:37:53 +0100364 tfLiteContext,
365 "TfLiteOpaqueArmnnDelegate: SplitV cannot infer split size for "
366 "more than one split in node #%d",
David Monahanc833cef2023-05-03 15:53:03 +0100367 nodeIndex);
368 return kTfLiteError;
369 }
370
371 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
Mike Kelly363b5722023-10-11 14:25:50 +0100372 splitDescriptor.SetAxis(axis);
David Monahanc833cef2023-05-03 15:53:03 +0100373 unsigned int accumSplit = 0;
Mike Kelly363b5722023-10-11 14:25:50 +0100374
David Monahanc833cef2023-05-03 15:53:03 +0100375 for (int j = 0; j < numSplits; ++j)
376 {
377 unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsTensorData[j]);
378
379 // Set the size of the views.
380 for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
381 {
382 unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
383 if (dimIdx == splitDim)
384 {
385 dimSize = splitSize;
386 }
387 splitDescriptor.SetViewSize(j, dimIdx, dimSize);
388 }
389
390 splitDescriptor.SetViewOriginCoord(j, splitDim, accumSplit);
391 accumSplit += splitSize;
392 }
393
394 armnn::BackendId setBackend;
395 if (!delegateData.m_Network)
396 {
397 // Check if supported
398 bool isSupported = false;
399 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SPLITV",
400 tfLiteContext,
401 IsSplitterSupported,
402 delegateData.m_Backends,
403 isSupported,
404 setBackend,
405 inputTensorInfo,
406 outputTensorInfos,
407 splitDescriptor);
408 return isSupported ? kTfLiteOk : kTfLiteError;
409 }
410
Mike Kellya2806502023-08-03 10:42:11 +0100411 auto layerName = GetName(armnn::LayerType::Splitter, nodeIndex);
412 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor, layerName.c_str());
David Monahanc833cef2023-05-03 15:53:03 +0100413 layer->SetBackendId(setBackend);
414 ARMNN_ASSERT(layer != nullptr);
415
416 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
417 {
418 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
419 }
420
421 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100422 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
David Monahanc833cef2023-05-03 15:53:03 +0100423 {
424 return kTfLiteError;
425 }
426
427 // Connect
428 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
429}
430
431} // namespace armnnOpaqueDelegate