blob: b27016e06a27be7e9cb50c6a00360335bcc6bb3d [file] [log] [blame]
Francis Murtaghc4fb0dd2023-03-16 17:01:56 +00001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
Kevin Mayb2831c52023-04-26 17:27:24 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
Kevin Mayb2831c52023-04-26 17:27:24 +010010namespace armnnOpaqueDelegate
11{
12
Teresa Charlinf69ae562023-04-27 14:42:23 +010013TfLiteStatus VisitGatherOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t operatorCode)
18{
19 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
20 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
21
22 int numInputs = 0;
23 const int* inputTensors;
24 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
Kevin Mayb2831c52023-04-26 17:27:24 +010025 {
Teresa Charlinf69ae562023-04-27 14:42:23 +010026 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
27 tfLiteContext,
28 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
29 nodeIndex);
30 return kTfLiteError;
Kevin Mayb2831c52023-04-26 17:27:24 +010031 }
Teresa Charlinf69ae562023-04-27 14:42:23 +010032
33 int numOutputs = 0;
34 const int* outputTensors;
35 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
36 {
37 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
38 tfLiteContext,
39 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
40 nodeIndex);
41 return kTfLiteError;
42 }
43
44 const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
45 inputTensors[0]);
46 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
47 {
48 return kTfLiteError;
49 }
50
51 const TfLiteOpaqueTensor* tfLiteIndicesTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
52 inputTensors[1]);
53 if (!IsValid(tfLiteContext, tfLiteIndicesTensor, operatorCode, nodeIndex))
54 {
55 return kTfLiteError;
56 }
57
58 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext,
59 outputTensors[0]);
60 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
61 {
62 return kTfLiteError;
63 }
64 auto* tfLiteNodeParameters = reinterpret_cast<TfLiteGatherParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
65 auto axis = tfLiteNodeParameters->axis;
66
67 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
68 const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteIndicesTensor);
69 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
70 armnn::GatherDescriptor gatherDescriptor;
71 gatherDescriptor.m_Axis = axis;
72
73 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
74 auto indicesDimensions = indicesTensorInfo.GetNumDimensions();
75 auto outputDimensions = outputTensorInfo.GetNumDimensions();
76 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
77 {
78 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
79 tfLiteContext,
80 "TfLiteArmnnOpaqueDelegate: Operation has invalid axis: %d. It is out of bounds [-%d, %d))",
81 axis, inputDimensions, inputDimensions);
82 return kTfLiteError;
83 }
84 if (outputDimensions != static_cast<unsigned int>(inputDimensions) + indicesDimensions - 1)
85 {
86 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
87 tfLiteContext,
88 "TfLiteArmnnOpaqueDelegate: Operation has invalid output dimensions: %d. "
89 "Output must be an (%d + %d - 1)-D tensor",
90 outputDimensions, inputDimensions, indicesDimensions);
91 return kTfLiteError;
92 }
93
94 armnn::BackendId setBackend;
95 if (!delegateData.m_Network)
96 {
97 // Check if supported
98 bool isSupported = false;
99 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("GATHER",
100 tfLiteContext,
101 IsGatherSupported,
102 delegateData.m_Backends,
103 isSupported,
104 setBackend,
105 inputTensorInfo,
106 indicesTensorInfo,
107 outputTensorInfo,
108 gatherDescriptor);
109 return isSupported ? kTfLiteOk : kTfLiteError;
110 }
111
112 armnn::IConnectableLayer* layer = delegateData.m_Network->AddGatherLayer(gatherDescriptor);
113 layer->SetBackendId(setBackend);
114 ARMNN_ASSERT(layer != nullptr);
115 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
116
117 auto inputsTensorsProcess = ProcessInputs(layer,
118 delegateData,
119 tfLiteContext,
120 tfLiteNode);
121 if (inputsTensorsProcess == kTfLiteError)
122 {
123 return inputsTensorsProcess;
124 }
125
126 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
127}
Kevin Mayb2831c52023-04-26 17:27:24 +0100128} // namespace armnnOpaqueDelegate