blob: 8248be9413651130449b9e7bfc99a33e30c7d5a1 [file] [log] [blame]
Sadik Armagan34fa1bd2020-11-27 12:40:52 +00001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "DelegateUtils.hpp"
9
10#include <algorithm>
11#include <iterator>
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000012#include <vector>
13
14namespace armnnDelegate
15{
16
/// Largest input-tensor rank accepted by the Split and SplitV visitors in this file;
/// inputs with more dimensions are rejected with kTfLiteError.
constexpr unsigned int MaxNumOfTensorDimensions{5U};
18
19TfLiteStatus VisitSplitOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t tfLiteSplitOperatorCode)
24{
25 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26
27 auto* splitParameters = reinterpret_cast<TfLiteSplitParams*>(tfLiteNode->builtin_data);
28 const unsigned int numSplits = NonNegative(splitParameters->num_splits, nodeIndex);
29
30 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
31
32 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
33 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
34 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitOperatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
40 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitOperatorCode, nodeIndex))
41 {
42 return kTfLiteError;
43 }
44
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000045 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
46
Finn Williams019840d2020-11-30 17:43:28 +000047 ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +000048 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
49 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
50 const unsigned int splitDim = axisTensorData[0];
51
52 std::vector<armnn::TensorInfo> outputs;
53 for (unsigned int i = 0; i < numSplits; ++i)
54 {
55 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
56 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitOperatorCode, nodeIndex))
57 {
58 return kTfLiteError;
59 }
60 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor));
61 }
62 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
63
64 auto inputDimSize = inputTensorInfo.GetNumDimensions();
65 if (inputDimSize > MaxNumOfTensorDimensions)
66 {
67 TF_LITE_MAYBE_KERNEL_LOG(
68 tfLiteContext,
69 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
70 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
71 return kTfLiteError;
72 }
73
74 std::vector<unsigned int> splitterDimSizes(inputDimSize);
75
76 // Add current input shape to splitterDimSizes
77 for (unsigned int i = 0; i < inputDimSize; ++i)
78 {
79 splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
80 }
81
82 if (splitterDimSizes[splitDim] % numSplits != 0)
83 {
84 TF_LITE_MAYBE_KERNEL_LOG(
85 tfLiteContext,
86 "TfLiteArmnnDelegate: Number of splits #%d must evenly divide the dimension #%d in node #%d: ",
87 numSplits, splitterDimSizes[splitDim], nodeIndex);
88 return kTfLiteError;
89 }
90 splitterDimSizes[splitDim] /= numSplits;
91
92 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
93 for (unsigned int j = 0; j < numSplits; ++j)
94 {
95 // Set the size of the views.
96 for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
97 {
98 splitDescriptor.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
99 }
100 splitDescriptor.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
101 }
102
103 if (!delegateData.m_Network)
104 {
105 // Check if supported
106 bool isSupported = false;
107 FORWARD_LAYER_SUPPORT_FUNC(__func__,
108 tfLiteContext,
109 IsSplitterSupported,
110 delegateData.m_Backends,
111 isSupported,
112 inputTensorInfo,
113 outputTensorInfos,
114 splitDescriptor);
115 return isSupported ? kTfLiteOk : kTfLiteError;
116 }
117
118 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
119 ARMNN_ASSERT(layer != nullptr);
120
121 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
122 {
123 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
124 }
125
126 // Connect the input slots
127 delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(0));
128
129 // Prepare output slots
130 for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
131 {
132 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
133 delegateData.m_OutputSlotForNode[
134 static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
135 }
136
137 return kTfLiteOk;
138}
139
140TfLiteStatus VisitSplitVOperator(DelegateData& delegateData,
141 TfLiteContext* tfLiteContext,
142 TfLiteNode* tfLiteNode,
143 int nodeIndex,
144 int32_t tfLiteSplitVOperatorCode)
145{
146 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
147
148 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
149 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
150 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitVOperatorCode, nodeIndex))
151 {
152 return kTfLiteError;
153 }
154
155 const TfLiteTensor& tfLiteSplitsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
156 if (!IsValid(tfLiteContext, tfLiteSplitsTensor, tfLiteSplitVOperatorCode, nodeIndex))
157 {
158 return kTfLiteError;
159 }
160
161 const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
162 if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitVOperatorCode, nodeIndex))
163 {
164 return kTfLiteError;
165 }
166
167 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
168 const armnn::TensorInfo& splitsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteSplitsTensor);
169 ARMNN_ASSERT(splitsTensorInfo.GetNumDimensions() == 1);
Finn Williams019840d2020-11-30 17:43:28 +0000170 ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000171
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000172 auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
173 std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
174
175 auto ComputeWrappedIndex = [](int index, unsigned int numDimensions)
176 {
177 int numDims = armnn::numeric_cast<int>(numDimensions);
178 int wrappedIndex = index < 0 ? numDims + index : index;
179 ARMNN_ASSERT(wrappedIndex >= 0);
180 ARMNN_ASSERT(wrappedIndex < numDims);
181
182 return static_cast<unsigned int>(wrappedIndex);
183 };
184
185 const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0],
186 inputTensorInfo.GetNumDimensions());
187
188 auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(tfLiteNode->builtin_data);
189 unsigned int numSplits = 0;
190 if (splitVParameters)
191 {
192 numSplits = NonNegative(splitVParameters->num_splits, nodeIndex);
193 }
194 else
195 {
196 numSplits = splitsTensorInfo.GetNumElements();
197 }
198
199 if (numSplits <= 0)
200 {
201 TF_LITE_MAYBE_KERNEL_LOG(
202 tfLiteContext, "TfLiteArmnnDelegate: Invalid number of splits %d in node #%d",
203 numSplits, nodeIndex);
204 return kTfLiteError;
205 }
206
207 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
208 std::vector<armnn::TensorInfo> outputs;
209 for (unsigned int i = 0; i < numSplits; ++i)
210 {
211 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
212 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitVOperatorCode, nodeIndex))
213 {
214 return kTfLiteError;
215 }
216 outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor));
217 }
218 const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
219
220 auto inputDimSize = inputTensorInfo.GetNumDimensions();
221 if (inputDimSize > MaxNumOfTensorDimensions)
222 {
223 TF_LITE_MAYBE_KERNEL_LOG(
224 tfLiteContext,
225 "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
226 "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
227 return kTfLiteError;
228 }
229
230 std::vector<int32_t> splitsTensorData(numSplits);
David Monahanc11ba462020-12-03 11:09:46 +0000231 std::memcpy(splitsTensorData.data(), tfLiteSplitsTensor.data.data, splitsTensorInfo.GetNumBytes());
232
Sadik Armagan34fa1bd2020-11-27 12:40:52 +0000233
234 unsigned int index = 0;
235 unsigned int inferredIndex = 0;
236 int numberOfInferred = 0;
237 int splitSum = 0;
238
239 for (auto splitData : splitsTensorData)
240 {
241 if (splitData < 0)
242 {
243 ++numberOfInferred;
244 inferredIndex = index;
245 }
246 else
247 {
248 splitSum += splitData;
249 }
250 ++index;
251 }
252
253 // Check for inferred axis
254 if (numberOfInferred == 0)
255 {
256 if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
257 {
258 TF_LITE_MAYBE_KERNEL_LOG(
259 tfLiteContext, "TfLiteArmnnDelegate: SplitV split_sizes does not sum to the dimension of value along"
260 " split_dim in node #%d", nodeIndex);
261 return kTfLiteError;
262 }
263 }
264 else if (numberOfInferred == 1)
265 {
266 splitsTensorData[inferredIndex] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
267 }
268 else
269 {
270 TF_LITE_MAYBE_KERNEL_LOG(
271 tfLiteContext, "TfLiteArmnnDelegate: SplitV cannot infer split size for more than one split in node #%d",
272 nodeIndex);
273 return kTfLiteError;
274 }
275
276 armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
277 unsigned int accumSplit = 0;
278 for (unsigned int j = 0; j < numSplits; ++j)
279 {
280 unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsTensorData[j]);
281
282 // Set the size of the views.
283 for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
284 {
285 unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
286 if (dimIdx == splitDim)
287 {
288 dimSize = splitSize;
289 }
290 splitDescriptor.SetViewSize(j, dimIdx, dimSize);
291 }
292
293 splitDescriptor.SetViewOriginCoord(j, splitDim, accumSplit);
294 accumSplit += splitSize;
295 }
296
297 if (!delegateData.m_Network)
298 {
299 // Check if supported
300 bool isSupported = false;
301 FORWARD_LAYER_SUPPORT_FUNC(__func__,
302 tfLiteContext,
303 IsSplitterSupported,
304 delegateData.m_Backends,
305 isSupported,
306 inputTensorInfo,
307 outputTensorInfos,
308 splitDescriptor);
309 return isSupported ? kTfLiteOk : kTfLiteError;
310 }
311
312 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
313 ARMNN_ASSERT(layer != nullptr);
314
315 for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
316 {
317 layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
318 }
319
320 // Connect
321 return Connect(layer, tfLiteNode, delegateData);
322}
323
324} // namespace armnnDelegate