blob: 30dbd0dc0e03a6797e726fc95a2f17e086e6590b [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
Ryan OSheaa544f0f2023-01-25 18:10:20 +00002// Copyright © 2020,2022-2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan62483be2020-10-23 17:14:43 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
9
Teresa Charlin98427a12020-11-25 18:22:57 +000010#include <algorithm>
11#include <iterator>
12#include <string>
13#include <vector>
Sadik Armagan62483be2020-10-23 17:14:43 +010014
15namespace armnnDelegate
16{
Sadik Armagan62483be2020-10-23 17:14:43 +010017TfLiteStatus VisitGatherOperator(DelegateData& delegateData,
18 TfLiteContext* tfLiteContext,
19 TfLiteNode* tfLiteNode,
20 int nodeIndex,
Teresa Charlin98427a12020-11-25 18:22:57 +000021 int32_t operatorCode)
Sadik Armagan62483be2020-10-23 17:14:43 +010022{
Teresa Charlin98427a12020-11-25 18:22:57 +000023 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000025
Teresa Charlin98427a12020-11-25 18:22:57 +000026 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27
28 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
30 {
31 return kTfLiteError;
32 }
33
34 const TfLiteTensor& tfLiteIndicesTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
35 if (!IsValid(tfLiteContext, tfLiteIndicesTensor, operatorCode, nodeIndex))
36 {
37 return kTfLiteError;
38 }
39
40 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
41 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
42 {
43 return kTfLiteError;
44 }
45
46 auto* gatherParameters = reinterpret_cast<TfLiteGatherParams*>(tfLiteNode->builtin_data);
47 auto axis = gatherParameters->axis;
48
49 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
50 const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteIndicesTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010051 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Teresa Charlin98427a12020-11-25 18:22:57 +000052 armnn::GatherDescriptor gatherDescriptor;
53 gatherDescriptor.m_Axis = axis;
54
55 auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
56 auto indicesDimensions = indicesTensorInfo.GetNumDimensions();
57 auto outputDimensions = outputTensorInfo.GetNumDimensions();
58 if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
59 {
60 TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext,
61 "TfLiteArmnnDelegate: Operation has invalid axis: %d. It is out of bounds [-%d, %d))",
62 axis, inputDimensions, inputDimensions);
63 return kTfLiteError;
64 }
65 if (outputDimensions != static_cast<unsigned int>(inputDimensions) + indicesDimensions - 1)
66 {
67 TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext,
68 "Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor",
69 outputDimensions, inputDimensions, indicesDimensions);
70 return kTfLiteError;
71 }
72
Cathal Corbett53837672022-09-01 11:34:37 +010073 armnn::BackendId setBackend;
Teresa Charlin98427a12020-11-25 18:22:57 +000074 if (!delegateData.m_Network)
75 {
76 // Check if supported
77 bool isSupported = false;
Sadik Armaganbfa767c2022-02-09 14:58:03 +000078 FORWARD_LAYER_SUPPORT_FUNC("GATHER",
Teresa Charlin98427a12020-11-25 18:22:57 +000079 tfLiteContext,
80 IsGatherSupported,
81 delegateData.m_Backends,
82 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010083 setBackend,
Teresa Charlin98427a12020-11-25 18:22:57 +000084 inputTensorInfo,
85 indicesTensorInfo,
86 outputTensorInfo,
87 gatherDescriptor);
88 return isSupported ? kTfLiteOk : kTfLiteError;
89 }
90
Mike Kelly07169c82023-08-02 13:23:09 +010091 auto layerName = GetLayerName(armnn::LayerType::Gather, nodeIndex);
92 armnn::IConnectableLayer* layer = delegateData.m_Network->AddGatherLayer(gatherDescriptor, layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +010093 layer->SetBackendId(setBackend);
Teresa Charlin98427a12020-11-25 18:22:57 +000094 ARMNN_ASSERT(layer != nullptr);
Teresa Charlin98427a12020-11-25 18:22:57 +000095 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
96
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010097 auto inputsTensorsProcess = ProcessInputs(layer,
98 delegateData,
99 tfLiteContext,
Mike Kelly07169c82023-08-02 13:23:09 +0100100 tfLiteNode,
101 nodeIndex);
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100102 if (inputsTensorsProcess == kTfLiteError)
103 {
104 return inputsTensorsProcess;
105 }
106
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000107 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100108}
Teresa Charlin98427a12020-11-25 18:22:57 +0000109} // namespace armnnDelegate