//
// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>

namespace armnnOpaqueDelegate
{
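// Queries the configured backends whether a ScatterNd layer with the given tensor infos
// and descriptor is supported, returning kTfLiteOk if the backends report support.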
TfLiteStatus ValidateScatterNdOperator(DelegateData& delegateData,
                                       TfLiteOpaqueContext* tfLiteContext,
                                       const armnn::TensorInfo& indicesInfo,
                                       const armnn::TensorInfo& updatesInfo,
                                       const armnn::TensorInfo& shapeInfo,
                                       const armnn::TensorInfo& outputInfo,
                                       const armnn::ScatterNdDescriptor& descriptor)
{
    bool isSupported = false;
    FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SCATTER_ND",
                                      tfLiteContext,
                                      IsScatterNdSupported,
                                      delegateData.m_Backends,
                                      isSupported,
                                      armnn::BackendId(),
                                      shapeInfo,
                                      indicesInfo,
                                      updatesInfo,
                                      outputInfo,
                                      descriptor);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

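// Translates a TFLite SCATTER_ND node (inputs: indices, updates, shape; one output) into
// an Arm NN ScatterNd layer, or only checks backend support when no network is being built.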
TfLiteStatus VisitScatterNdOperator(DelegateData& delegateData,
                                    TfLiteOpaqueContext* tfLiteContext,
                                    TfLiteOpaqueNode* tfLiteNode,
                                    int nodeIndex,
                                    int32_t scatterNdOperatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    // Gather input indices and use to get input tensor.
    auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
    const int* inputTensors;
    if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
                nodeIndex);
        return kTfLiteError;
    }

    // Gather output indices and use to get output tensor.
    int numOutputs = 0;
    const int* outputTensors;
    if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
                nodeIndex);
        return kTfLiteError;
    }

    // The indices tensor holds the positions into which the data is updated/scattered
    const TfLiteOpaqueTensor* tfLiteIndicesTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
    if (IsDynamicTensor(tfLiteIndicesTensor))
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
                scatterNdOperatorCode, nodeIndex);
        return kTfLiteError;
    }
78
79 // The updates tensor provides the data which will be updated/scattered into the relevant indices
80 const TfLiteOpaqueTensor* tfLiteUpdatesTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
81 if (IsDynamicTensor(tfLiteUpdatesTensor))
82 {
83 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
84 tfLiteContext,
85 "TfLiteArmnnOpaqueDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
86 scatterNdOperatorCode, nodeIndex);
87 return kTfLiteError;
88 }
89
    // For TFLite ScatterNd there is no input tensor
    // The shape tensor is a 1D tensor which represents the shape of an input tensor to be filled with zeros
    const TfLiteOpaqueTensor* tfLiteShapeTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[2]);
    if (IsDynamicTensor(tfLiteShapeTensor))
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
                scatterNdOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    // The output tensor
    const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
    if (IsDynamicTensor(tfLiteOutputTensor))
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
                scatterNdOperatorCode, nodeIndex);
        return kTfLiteError;
    }

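    // Build the Arm NN tensor infos used for the support query and the layer's output slot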
    const armnn::TensorInfo& shapeTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteShapeTensor);
    const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteIndicesTensor);
    const armnn::TensorInfo& updatesTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteUpdatesTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);

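    // TFLite ScatterNd scatters the updates into a zero-filled tensor of the given shape,
    // so use the Update function with the input and axis options disabled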
    armnn::ScatterNdDescriptor scatterNdDescriptor;
    scatterNdDescriptor.m_Function = armnn::ScatterNdFunction::Update;
    scatterNdDescriptor.m_InputEnabled = false;
    scatterNdDescriptor.m_Axis = 0;
    scatterNdDescriptor.m_AxisEnabled = false;

    // Check output dimensions
    if (shapeTensorInfo.GetShape().GetNumElements() != outputTensorInfo.GetNumDimensions())
    {
        TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
                tfLiteContext,
                "TfLiteArmnnOpaqueDelegate: Shape tensor element count and output tensor dimension count differ. "
                "Operator: #%d node #%d: ",
                scatterNdOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    // No network pointer indicates that only support for this operator should be checked
    if (!delegateData.m_Network)
    {
        return ValidateScatterNdOperator(delegateData,
                                         tfLiteContext,
                                         indicesTensorInfo,
                                         updatesTensorInfo,
                                         shapeTensorInfo,
                                         outputTensorInfo,
                                         scatterNdDescriptor);
    }

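    // Add the ScatterNd layer to the network; the layer name is derived from the layer type and node index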
    auto layerName = GetName(armnn::LayerType::ScatterNd, nodeIndex);
    armnn::IConnectableLayer* layer = delegateData.m_Network->AddScatterNdLayer(scatterNdDescriptor, layerName.c_str());

    if (layer == nullptr)
    {
        return kTfLiteError;
    }

    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);

    if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
    {
        return kTfLiteError;
    }

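    // Connect the inputs in the order the Arm NN ScatterNd layer consumes them:
    // shape on slot 0, indices on slot 1 and updates on slot 2
    // (the TFLite node supplies them as indices, updates, shape)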
    delegateData.m_OutputSlotForNode[inputTensors[2]]->Connect(layer->GetInputSlot(0));
    delegateData.m_OutputSlotForNode[inputTensors[0]]->Connect(layer->GetInputSlot(1));
    delegateData.m_OutputSlotForNode[inputTensors[1]]->Connect(layer->GetInputSlot(2));

    // Prepare output slots
    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    delegateData.m_OutputSlotForNode[static_cast<unsigned long>(outputTensors[0])] = &outputSlot;

    return kTfLiteOk;
}

} // namespace armnnOpaqueDelegate