blob: a49b768873f657f969cdd60994ded40848197704 [file] [log] [blame]
Teresa Charlind5c0ed22022-04-25 18:23:41 +01001//
Ryan OSheaa544f0f2023-01-25 18:10:20 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Teresa Charlind5c0ed22022-04-25 18:23:41 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Matthew Sloyan11572322023-03-16 10:17:51 +00008#include <ClassicDelegateUtils.hpp>
9
Teresa Charlind5c0ed22022-04-25 18:23:41 +010010#include <algorithm>
11#include <iterator>
12#include <string>
13#include <vector>
14
15namespace armnnDelegate
16{
17TfLiteStatus VisitGatherNdOperator(DelegateData& delegateData,
18 TfLiteContext* tfLiteContext,
19 TfLiteNode* tfLiteNode,
20 int nodeIndex,
21 int32_t operatorCode)
22{
23 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25
26 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27
28 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
30 {
31 return kTfLiteError;
32 }
33
34 const TfLiteTensor& tfLiteIndicesTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
35 if (!IsValid(tfLiteContext, tfLiteIndicesTensor, operatorCode, nodeIndex))
36 {
37 return kTfLiteError;
38 }
39
40 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
41 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
42 {
43 return kTfLiteError;
44 }
45
46 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
47 const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteIndicesTensor);
Sadik Armagan90a119b2022-08-05 16:12:49 +010048 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
Teresa Charlind5c0ed22022-04-25 18:23:41 +010049
Cathal Corbett53837672022-09-01 11:34:37 +010050 armnn::BackendId setBackend;
Teresa Charlind5c0ed22022-04-25 18:23:41 +010051 if (!delegateData.m_Network)
52 {
53 // Check if supported
54 bool isSupported = false;
55 FORWARD_LAYER_SUPPORT_FUNC("GATHER_ND",
56 tfLiteContext,
57 IsGatherNdSupported,
58 delegateData.m_Backends,
59 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +010060 setBackend,
Teresa Charlind5c0ed22022-04-25 18:23:41 +010061 inputTensorInfo,
62 indicesTensorInfo,
63 outputTensorInfo);
64 return isSupported ? kTfLiteOk : kTfLiteError;
65 }
66
Mike Kelly07169c82023-08-02 13:23:09 +010067 auto layerName = GetLayerName(armnn::LayerType::GatherNd, nodeIndex);
68 armnn::IConnectableLayer* layer = delegateData.m_Network->AddGatherNdLayer(layerName.c_str());
Cathal Corbett53837672022-09-01 11:34:37 +010069 layer->SetBackendId(setBackend);
Teresa Charlind5c0ed22022-04-25 18:23:41 +010070 ARMNN_ASSERT(layer != nullptr);
71 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
72
73 auto inputsTensorsProcess = ProcessInputs(layer,
74 delegateData,
75 tfLiteContext,
Mike Kelly07169c82023-08-02 13:23:09 +010076 tfLiteNode,
77 nodeIndex);
Teresa Charlind5c0ed22022-04-25 18:23:41 +010078 if (inputsTensorsProcess == kTfLiteError)
79 {
80 return inputsTensorsProcess;
81 }
82
Ryan OSheaa544f0f2023-01-25 18:10:20 +000083 return Connect(layer, tfLiteNode, delegateData);
Teresa Charlind5c0ed22022-04-25 18:23:41 +010084}
85} // namespace armnnDelegate