blob: a767d01ad4082c03b14929cde74d9dfa8b2dbb5e [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
Kevin Mayb2831c52023-04-26 17:27:24 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
Kevin Mayb2831c52023-04-26 17:27:24 +010010namespace armnnOpaqueDelegate
11{
12TfLiteStatus VisitGatherNdOperator(DelegateData& delegateData,
13 TfLiteOpaqueContext* tfLiteContext,
14 TfLiteOpaqueNode* tfLiteNode,
15 int nodeIndex,
16 int32_t operatorCode)
17{
18 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
19 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
20
21 int numInputs = 0;
22 const int* inputTensors;
23 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
24 {
25 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
26 tfLiteContext,
27 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
28 nodeIndex);
29 return kTfLiteError;
30 }
31
32 int numOutputs = 0;
33 const int* outputTensors;
34 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
35 {
36 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
37 tfLiteContext,
38 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
39 nodeIndex);
40 return kTfLiteError;
41 }
42
43 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
44 inputTensors[0]);
45 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
46 {
47 return kTfLiteError;
48 }
49
50 const TfLiteOpaqueTensor* tfLiteIndicesTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
51 inputTensors[1]);
52 if (!IsValid(tfLiteContext, tfLiteIndicesTensor, operatorCode, nodeIndex))
53 {
54 return kTfLiteError;
55 }
56
57 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
58 outputTensors[0]);
59 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
60 {
61 return kTfLiteError;
62 }
63
64 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
65 const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteIndicesTensor);
66 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
67
68 armnn::BackendId setBackend;
69 if (!delegateData.m_Network)
70 {
71 // Check if supported
72 bool isSupported = false;
73 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("GATHER_ND",
74 tfLiteContext,
75 IsGatherNdSupported,
76 delegateData.m_Backends,
77 isSupported,
78 setBackend,
79 inputTensorInfo,
80 indicesTensorInfo,
81 outputTensorInfo);
82 return isSupported ? kTfLiteOk : kTfLiteError;
83 }
84
85 armnn::IConnectableLayer* layer = delegateData.m_Network->AddGatherNdLayer();
86 layer->SetBackendId(setBackend);
87 ARMNN_ASSERT(layer != nullptr);
88 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
89
90 auto inputsTensorsProcess = ProcessInputs(layer,
91 delegateData,
92 tfLiteContext,
93 tfLiteNode);
94 if (inputsTensorsProcess == kTfLiteError)
95 {
96 return inputsTensorsProcess;
97 }
98
99 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
100}
101} // namespace armnnOpaqueDelegate