//
// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

// Assumed include: ClassicDelegateUtils.hpp declares DelegateData, FORWARD_LAYER_SUPPORT_FUNC and
// the other shared delegate helpers (ValidateNumInputs, GetTensorInfoForTfLiteTensor, ...) used below
#include <ClassicDelegateUtils.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
#include <tensorflow/lite/kernels/internal/tensor_ctypes.h>
#include <tensorflow/lite/schema/schema_generated.h>

namespace armnnDelegate
{
TfLiteStatus ValidateScatterNdOperator(DelegateData& delegateData,
                                       TfLiteContext* tfLiteContext,
                                       const armnn::TensorInfo& indicesInfo,
                                       const armnn::TensorInfo& updatesInfo,
                                       const armnn::TensorInfo& shapeInfo,
                                       const armnn::TensorInfo& outputInfo,
                                       const armnn::ScatterNdDescriptor& descriptor)
{
    bool isSupported = false;
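    // Note: the tensor infos are passed in the ArmNN ScatterNd input order (shape, indices, updates),
    // matching the input slot order used when the layer is connected in VisitScatterNdOperator below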
    FORWARD_LAYER_SUPPORT_FUNC("SCATTER_ND",
                               tfLiteContext,
                               IsScatterNdSupported,
                               delegateData.m_Backends,
                               isSupported,
                               armnn::BackendId(),
                               shapeInfo,
                               indicesInfo,
                               updatesInfo,
                               outputInfo,
                               descriptor);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

TfLiteStatus VisitScatterNdOperator(DelegateData& delegateData,
                                    TfLiteContext* tfLiteContext,
                                    TfLiteNode* tfLiteNode,
                                    int nodeIndex,
                                    int32_t scatterNdOperatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;

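    // TfLite SCATTER_ND inputs: 0 = indices, 1 = updates, 2 = shape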
    // The indices tensor holds the positions that the data is updated/scattered into
    const TfLiteTensor& tfLiteIndicesTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (IsDynamicTensor(tfLiteIndicesTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
            scatterNdOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    // The updates tensor provides the data which will be updated/scattered into the relevant indices
    const TfLiteTensor& tfLiteUpdatesTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
    if (IsDynamicTensor(tfLiteUpdatesTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
            scatterNdOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    // For TfLite SCATTER_ND there is no input tensor to update; instead, the shape tensor is a 1D tensor
    // describing the shape of a zero-filled tensor into which the updates are scattered
    const TfLiteTensor& tfLiteShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
    if (IsDynamicTensor(tfLiteShapeTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
            scatterNdOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    // The output tensor
    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (IsDynamicTensor(tfLiteOutputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
            scatterNdOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteIndicesTensor);
    const armnn::TensorInfo& updatesTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteUpdatesTensor);
    const armnn::TensorInfo& shapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteShapeTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);

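    // TfLite SCATTER_ND always scatters the updates into a zero-filled tensor of the given shape,
    // so the descriptor uses the Update function with both the input tensor and the axis disabled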
    armnn::ScatterNdDescriptor scatterNdDescriptor;
    scatterNdDescriptor.m_Function = armnn::ScatterNdFunction::Update;
    scatterNdDescriptor.m_InputEnabled = false;
    scatterNdDescriptor.m_Axis = 0;
    scatterNdDescriptor.m_AxisEnabled = false;

    // The shape tensor elements define the output shape, so its element count must match the output rank
    if (shapeTensorInfo.GetShape().GetNumElements() != outputTensorInfo.GetNumDimensions())
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Shape tensor number of elements and output tensor number of dimensions differ. "
            "Operator: #%d node #%d: ",
            scatterNdOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    // No network pointer indicates that only support for this operator should be checked
    if (!delegateData.m_Network)
    {
        return ValidateScatterNdOperator(delegateData,
                                         tfLiteContext,
                                         indicesTensorInfo,
                                         updatesTensorInfo,
                                         shapeTensorInfo,
                                         outputTensorInfo,
                                         scatterNdDescriptor);
    }

    auto layerName = GetLayerName(armnn::LayerType::ScatterNd, nodeIndex);
    armnn::IConnectableLayer* layer = delegateData.m_Network->AddScatterNdLayer(scatterNdDescriptor, layerName.c_str());

    if (layer == nullptr)
    {
        return kTfLiteError;
    }

    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);

    if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
    {
        return kTfLiteError;
    }

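    // The ScatterNd layer should expose exactly as many output slots as the TfLite node declares outputs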
    if (static_cast<unsigned int>(tfLiteNode->outputs->size) != layer->GetNumOutputSlots())
    {
        return kTfLiteError;
    }

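    // Connect the inputs in the ArmNN layer order: shape -> slot 0, indices -> slot 1, updates -> slot 2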
    delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[2]]->Connect(layer->GetInputSlot(0));
    delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(layer->GetInputSlot(1));
    delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(2));

    // Prepare output slots
    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->outputs->data[0])] = &outputSlot;

    return kTfLiteOk;
}

} // namespace armnnDelegate