blob: 4c9cf8283253dd2c4fdcf69a95f8ecd125a52563 [file] [log] [blame]
//
// Copyright © 2020,2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Teresa Charlinad1b3d72023-03-14 12:10:28 +00008#include <DelegateUtils.hpp>
Teresa Charlin98427a12020-11-25 18:22:57 +00009#include <algorithm>
10#include <iterator>
11#include <string>
12#include <vector>
Sadik Armagan62483be2020-10-23 17:14:43 +010013
14namespace armnnDelegate
15{
Sadik Armagan62483be2020-10-23 17:14:43 +010016TfLiteStatus VisitGatherOperator(DelegateData& delegateData,
17 TfLiteContext* tfLiteContext,
18 TfLiteNode* tfLiteNode,
19 int nodeIndex,
Teresa Charlin98427a12020-11-25 18:22:57 +000020 int32_t operatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010021{
Teresa Charlin98427a12020-11-25 18:22:57 +000022 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
23 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000024
Teresa Charlin98427a12020-11-25 18:22:57 +000025 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
26
27 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
28 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
29 {
30 return kTfLiteError;
31 }
32
33 const TfLiteTensor& tfLiteIndicesTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
34 if (!IsValid(tfLiteContext, tfLiteIndicesTensor, operatorCode, nodeIndex))
35 {
36 return kTfLiteError;
37 }
38
39 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
40 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
41 {
42 return kTfLiteError;
43 }
44
45 auto* gatherParameters = reinterpret_cast<TfLiteGatherParams*>(tfLiteNode->builtin_data);
46 auto axis = gatherParameters->axis;
47
48 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
49 const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteIndicesTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010050 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Teresa Charlin98427a12020-11-25 18:22:57 +000051 armnn::GatherDescriptor gatherDescriptor;
52 gatherDescriptor.m_Axis = axis;
53
54 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
55 auto indicesDimensions = indicesTensorInfo.GetNumDimensions();
56 auto outputDimensions = outputTensorInfo.GetNumDimensions();
57 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
58 {
59 TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext,
60 "TfLiteArmnnDelegate: Operation has invalid axis: %d. It is out of bounds [-%d, %d))",
61 axis, inputDimensions, inputDimensions);
62 return kTfLiteError;
63 }
64 if (outputDimensions != static_cast<unsigned int>(inputDimensions) + indicesDimensions - 1)
65 {
66 TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext,
67 "Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor",
68 outputDimensions, inputDimensions, indicesDimensions);
69 return kTfLiteError;
70 }
71
Cathal Corbett53837672022-09-01 11:34:37 +010072 armnn::BackendId setBackend;
Teresa Charlin98427a12020-11-25 18:22:57 +000073 if (!delegateData.m_Network)
74 {
75 // Check if supported
76 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +000077 FORWARD_LAYER_SUPPORT_FUNC("GATHER",
Teresa Charlin98427a12020-11-25 18:22:57 +000078 tfLiteContext,
79 IsGatherSupported,
80 delegateData.m_Backends,
81 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010082 setBackend,
Teresa Charlin98427a12020-11-25 18:22:57 +000083 inputTensorInfo,
84 indicesTensorInfo,
85 outputTensorInfo,
86 gatherDescriptor);
87 return isSupported ? kTfLiteOk : kTfLiteError;
88 }
89
90 armnn::IConnectableLayer* layer = delegateData.m_Network->AddGatherLayer(gatherDescriptor);
Cathal Corbett53837672022-09-01 11:34:37 +010091 layer->SetBackendId(setBackend);
Teresa Charlin98427a12020-11-25 18:22:57 +000092 ARMNN_ASSERT(layer != nullptr);
Teresa Charlin98427a12020-11-25 18:22:57 +000093 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
94
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010095 auto inputsTensorsProcess = ProcessInputs(layer,
96 delegateData,
97 tfLiteContext,
98 tfLiteNode);
99 if (inputsTensorsProcess == kTfLiteError)
100 {
101 return inputsTensorsProcess;
102 }
103
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000104 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100105}
Teresa Charlin98427a12020-11-25 18:22:57 +0000106} // namespace armnnDelegate