//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <OpaqueDelegateUtils.hpp>
#include <MultiLayerFacade.hpp>


namespace armnnOpaqueDelegate
{

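// Each Validate*Operator helper below asks the configured backends whether the corresponding
// armnn::BinaryOperation is supported for the given tensor infos, via
// FORWARD_LAYER_OPAQUE_SUPPORT_FUNC, and maps the answer onto a TfLiteStatus.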
TfLiteStatus ValidateAddOperator(DelegateData& delegateData,
                                 TfLiteOpaqueContext* tfLiteContext,
                                 const armnn::TensorInfo& inputInfo1,
                                 const armnn::TensorInfo& inputInfo2,
                                 const armnn::TensorInfo& outputInfo)
{
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("ADD",
                                          tfLiteContext,
                                          IsElementwiseBinarySupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo1,
                                          inputInfo2,
                                          outputTensorInfo,
                                          armnn::BinaryOperation::Add);
    };

    validateFunc(outputInfo, isSupported);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

TfLiteStatus ValidateDivOperator(DelegateData& delegateData,
                                 TfLiteOpaqueContext* tfLiteContext,
                                 const armnn::TensorInfo& inputInfo1,
                                 const armnn::TensorInfo& inputInfo2,
                                 const armnn::TensorInfo& outputInfo)
{
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("DIV",
                                          tfLiteContext,
                                          IsElementwiseBinarySupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo1,
                                          inputInfo2,
                                          outputTensorInfo,
                                          armnn::BinaryOperation::Div);
    };

    validateFunc(outputInfo, isSupported);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

TfLiteStatus ValidateFloorDivOperator(DelegateData& delegateData,
                                      TfLiteOpaqueContext* tfLiteContext,
                                      const armnn::TensorInfo& inputInfo1,
                                      const armnn::TensorInfo& inputInfo2,
                                      const armnn::TensorInfo& outputInfo)
{
    // First validate that the Div operator is supported,
    // then that the Floor operator is supported.
    TfLiteStatus status = ValidateDivOperator(delegateData, tfLiteContext, inputInfo1, inputInfo2, outputInfo);
    if (status != kTfLiteOk)
    {
        return status;
    }
    // If the inputs and output of the Div are all Signed32, the Floor operator does not need to be added afterwards.
    if (AreAllSigned32(inputInfo1, inputInfo2, outputInfo))
    {
        return status;
    }
    // In case one of the inputs to the Div is broadcast, pass the full-sized
    // input tensor to the Floor validation routine.
    armnn::TensorInfo floorInputInfo = inputInfo1;
    if (inputInfo1.GetNumDimensions() < inputInfo2.GetNumDimensions())
    {
        floorInputInfo = inputInfo2;
    }
    status = ValidateFloorOperator(delegateData, tfLiteContext, floorInputInfo, outputInfo);
    return status;
}

TfLiteStatus ValidateMaximumOperator(DelegateData& delegateData,
                                     TfLiteOpaqueContext* tfLiteContext,
                                     const armnn::TensorInfo& inputInfo1,
                                     const armnn::TensorInfo& inputInfo2,
                                     const armnn::TensorInfo& outputInfo)
{
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("MAXIMUM",
                                          tfLiteContext,
                                          IsElementwiseBinarySupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo1,
                                          inputInfo2,
                                          outputTensorInfo,
                                          armnn::BinaryOperation::Maximum);
    };

    validateFunc(outputInfo, isSupported);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

TfLiteStatus ValidateMinimumOperator(DelegateData& delegateData,
                                     TfLiteOpaqueContext* tfLiteContext,
                                     const armnn::TensorInfo& inputInfo1,
                                     const armnn::TensorInfo& inputInfo2,
                                     const armnn::TensorInfo& outputInfo)
{
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("MINIMUM",
                                          tfLiteContext,
                                          IsElementwiseBinarySupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo1,
                                          inputInfo2,
                                          outputTensorInfo,
                                          armnn::BinaryOperation::Minimum);
    };

    validateFunc(outputInfo, isSupported);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

TfLiteStatus ValidateMulOperator(DelegateData& delegateData,
                                 TfLiteOpaqueContext* tfLiteContext,
                                 const armnn::TensorInfo& inputInfo1,
                                 const armnn::TensorInfo& inputInfo2,
                                 const armnn::TensorInfo& outputInfo)
{
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("MUL",
                                          tfLiteContext,
                                          IsElementwiseBinarySupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo1,
                                          inputInfo2,
                                          outputTensorInfo,
                                          armnn::BinaryOperation::Mul);
    };

    validateFunc(outputInfo, isSupported);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

TfLiteStatus ValidatePowerOperator(DelegateData& delegateData,
                                   TfLiteOpaqueContext* tfLiteContext,
                                   const armnn::TensorInfo& inputInfo1,
                                   const armnn::TensorInfo& inputInfo2,
                                   const armnn::TensorInfo& outputInfo)
{
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("POWER",
                                          tfLiteContext,
                                          IsElementwiseBinarySupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo1,
                                          inputInfo2,
                                          outputTensorInfo,
                                          armnn::BinaryOperation::Power);
    };

    validateFunc(outputInfo, isSupported);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

TfLiteStatus ValidateSquaredDifferenceOperator(DelegateData& delegateData,
                                               TfLiteOpaqueContext* tfLiteContext,
                                               const armnn::TensorInfo& inputInfo1,
                                               const armnn::TensorInfo& inputInfo2,
                                               const armnn::TensorInfo& outputInfo)
{
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SQUAREDDIFFERENCE",
                                          tfLiteContext,
                                          IsElementwiseBinarySupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo1,
                                          inputInfo2,
                                          outputTensorInfo,
                                          armnn::BinaryOperation::SqDiff);
    };

    validateFunc(outputInfo, isSupported);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

TfLiteStatus ValidateSubOperator(DelegateData& delegateData,
                                 TfLiteOpaqueContext* tfLiteContext,
                                 const armnn::TensorInfo& inputInfo1,
                                 const armnn::TensorInfo& inputInfo2,
                                 const armnn::TensorInfo& outputInfo)
{
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SUB",
                                          tfLiteContext,
                                          IsElementwiseBinarySupported,
                                          delegateData.m_Backends,
                                          isSupported,
                                          armnn::BackendId(),
                                          inputInfo1,
                                          inputInfo2,
                                          outputTensorInfo,
                                          armnn::BinaryOperation::Sub);
    };

    validateFunc(outputInfo, isSupported);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

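// FLOOR_DIV is lowered to a Div layer followed, unless the output type is Signed32, by a Floor layer.
// The {first, last} layers of that chain are returned so the caller can connect the node inputs to the
// first layer and the node output to the last one.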
std::pair<armnn::IConnectableLayer*, armnn::IConnectableLayer*> AddFloorDivLayer(
        DelegateData& delegateData,
        const armnn::TensorInfo& outputTensorInfo,
        int nodeIndex)
{
    auto layerName = GetName(armnn::BinaryOperation::Div, nodeIndex);
    armnn::IConnectableLayer* divisionLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
            armnn::BinaryOperation::Div,
            layerName.c_str());

    // If the output of the Div is Signed32, the Floor layer is not required.
    if (armnn::DataType::Signed32 == outputTensorInfo.GetDataType())
    {
        return std::make_pair(divisionLayer, divisionLayer);
    }
    armnn::IOutputSlot& outputSlot = divisionLayer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);
    auto floorName = GetName(armnn::LayerType::Floor, nodeIndex);
    armnn::IConnectableLayer* floorLayer = delegateData.m_Network->AddFloorLayer(floorName.c_str());
    outputSlot.Connect(floorLayer->GetInputSlot(0));
    return std::make_pair(divisionLayer, floorLayer);
}

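// Handles every elementwise binary builtin operator (ADD, DIV, FLOOR_DIV, MAXIMUM, MINIMUM, MUL,
// POW, SQUARED_DIFFERENCE and SUB). With no network present it only reports whether the node can
// be supported; otherwise it adds the matching layer(s) to the network, connects them and applies
// any fused activation.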
TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
                                            TfLiteOpaqueContext* tfLiteContext,
                                            TfLiteOpaqueNode* tfLiteNode,
                                            int nodeIndex,
                                            int32_t elementwiseBinaryOperatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    // Gather the input indices and use them to get the input tensors.
    auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
    const int* inputTensors;
    if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
                nodeIndex);
        return kTfLiteError;
    }
    const TfLiteOpaqueTensor* tfLiteInputTensor0 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
    if (!IsValid(tfLiteContext, tfLiteInputTensor0, elementwiseBinaryOperatorCode, nodeIndex))
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Invalid input tensor in operator #%d node #%d: ",
                elementwiseBinaryOperatorCode, nodeIndex);
        return kTfLiteError;
    }
    // Use the input indices to get the second input tensor.
    const TfLiteOpaqueTensor* tfLiteInputTensor1 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
    if (!IsValid(tfLiteInputTensor1))
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Invalid input tensor in operator #%d node #%d: ",
                elementwiseBinaryOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    // Gather the output indices and use them to get the output tensor.
    int numOutputs = 0;
    const int* outputTensors;
    if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
                nodeIndex);
        return kTfLiteError;
    }
    const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, elementwiseBinaryOperatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor0);
    armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor1);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);

    // Check if we need to expand the dims of the input tensor infos.
    // This is required for a few of the backends.
    if (inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
    {
        ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
    }

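    // The builtin data, where present, carries the fused activation. Check that it can be supported
    // on the output before anything is added to the network.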
    auto* tfLiteNodeParameters = reinterpret_cast<TfLiteAddParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
    TfLiteFusedActivation activationType = kTfLiteActNone;
    if (tfLiteNodeParameters)
    {
        activationType = tfLiteNodeParameters->activation;
        TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData,
                                                                        tfLiteContext,
                                                                        outputTensorInfo,
                                                                        outputTensorInfo,
                                                                        activationType);
        if (activationStatus != kTfLiteOk)
        {
            return kTfLiteError;
        }
    }

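    // No network to build: the delegate is only being asked whether this node can be supported,
    // so run the relevant validation function and return its verdict.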
    if (!delegateData.m_Network)
    {
        switch (elementwiseBinaryOperatorCode)
        {
            case kTfLiteBuiltinAdd:
                return ValidateAddOperator(delegateData,
                                           tfLiteContext,
                                           inputTensorInfo0,
                                           inputTensorInfo1,
                                           outputTensorInfo);
            case kTfLiteBuiltinDiv:
                return ValidateDivOperator(delegateData,
                                           tfLiteContext,
                                           inputTensorInfo0,
                                           inputTensorInfo1,
                                           outputTensorInfo);
            case kTfLiteBuiltinFloorDiv:
                return ValidateFloorDivOperator(delegateData,
                                                tfLiteContext,
                                                inputTensorInfo0,
                                                inputTensorInfo1,
                                                outputTensorInfo);
            case kTfLiteBuiltinMaximum:
                return ValidateMaximumOperator(delegateData,
                                               tfLiteContext,
                                               inputTensorInfo0,
                                               inputTensorInfo1,
                                               outputTensorInfo);
            case kTfLiteBuiltinMinimum:
                return ValidateMinimumOperator(delegateData,
                                               tfLiteContext,
                                               inputTensorInfo0,
                                               inputTensorInfo1,
                                               outputTensorInfo);
            case kTfLiteBuiltinMul:
                return ValidateMulOperator(delegateData,
                                           tfLiteContext,
                                           inputTensorInfo0,
                                           inputTensorInfo1,
                                           outputTensorInfo);
            case kTfLiteBuiltinPow:
                return ValidatePowerOperator(delegateData,
                                             tfLiteContext,
                                             inputTensorInfo0,
                                             inputTensorInfo1,
                                             outputTensorInfo);
            case kTfLiteBuiltinSquaredDifference:
                return ValidateSquaredDifferenceOperator(delegateData,
                                                         tfLiteContext,
                                                         inputTensorInfo0,
                                                         inputTensorInfo1,
                                                         outputTensorInfo);
            case kTfLiteBuiltinSub:
                return ValidateSubOperator(delegateData,
                                           tfLiteContext,
                                           inputTensorInfo0,
                                           inputTensorInfo1,
                                           outputTensorInfo);
            default:
                return kTfLiteError;
        }
    }

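    // Build mode: add the requested binary layer to the network. FLOOR_DIV may expand into two
    // layers (Div + Floor); MultiLayerFacade presents that pair as a single layer so the
    // connection code below can treat every case uniformly.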
    armnn::IConnectableLayer* elementwiseBinaryLayer = nullptr;
    armnnDelegate::MultiLayerFacade multiLayer;
    std::string layerName;
    switch (elementwiseBinaryOperatorCode)
    {
        case kTfLiteBuiltinAdd:
            layerName = GetName(armnn::BinaryOperation::Add, nodeIndex);
            elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(armnn::BinaryOperation::Add,
                                                                                       layerName.c_str());
            break;
        case kTfLiteBuiltinDiv:
            layerName = GetName(armnn::BinaryOperation::Div, nodeIndex);
            elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(armnn::BinaryOperation::Div,
                                                                                       layerName.c_str());
            break;
        case kTfLiteBuiltinFloorDiv:
        {
            auto layers = AddFloorDivLayer(delegateData, outputTensorInfo, nodeIndex);
            multiLayer.AssignValues(layers.first, layers.second);
            elementwiseBinaryLayer = &multiLayer;
        }
            break;
        case kTfLiteBuiltinMaximum:
            layerName = GetName(armnn::BinaryOperation::Maximum, nodeIndex);
            elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(armnn::BinaryOperation::Maximum,
                                                                                       layerName.c_str());
            break;
        case kTfLiteBuiltinMinimum:
            layerName = GetName(armnn::BinaryOperation::Minimum, nodeIndex);
            elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(armnn::BinaryOperation::Minimum,
                                                                                       layerName.c_str());
            break;
        case kTfLiteBuiltinMul:
            layerName = GetName(armnn::BinaryOperation::Mul, nodeIndex);
            elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(armnn::BinaryOperation::Mul,
                                                                                       layerName.c_str());
            break;
        case kTfLiteBuiltinPow:
            layerName = GetName(armnn::BinaryOperation::Power, nodeIndex);
            elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(armnn::BinaryOperation::Power,
                                                                                       layerName.c_str());
            break;
        case kTfLiteBuiltinSquaredDifference:
            layerName = GetName(armnn::BinaryOperation::SqDiff, nodeIndex);
            elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(armnn::BinaryOperation::SqDiff,
                                                                                       layerName.c_str());
            break;
        case kTfLiteBuiltinSub:
            layerName = GetName(armnn::BinaryOperation::Sub, nodeIndex);
            elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(armnn::BinaryOperation::Sub,
                                                                                       layerName.c_str());
            break;
        default:
            return kTfLiteError;
    }
    ARMNN_ASSERT(elementwiseBinaryLayer != nullptr);
    armnn::IOutputSlot& outputSlot = elementwiseBinaryLayer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    auto inputsTensorsProcess = ProcessInputs(elementwiseBinaryLayer,
                                              delegateData,
                                              tfLiteContext,
                                              tfLiteNode,
                                              nodeIndex);
    if (inputsTensorsProcess == kTfLiteError)
    {
        return inputsTensorsProcess;
    }

    if (Connect(elementwiseBinaryLayer, tfLiteContext, tfLiteNode, delegateData) != kTfLiteOk)
    {
        return kTfLiteError;
    }

    if (!tfLiteNodeParameters)
    {
        // No Activation
        return kTfLiteOk;
    }
    // Check and Create Activation
    return FusedActivation(tfLiteContext, tfLiteNode, activationType, elementwiseBinaryLayer, 0, delegateData,
                           nodeIndex);
}

} // namespace armnnOpaqueDelegate