//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <OpaqueDelegateUtils.hpp>

namespace armnnOpaqueDelegate
{

13TfLiteStatus VisitSliceOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t tfLiteSliceOperatorCode)
18{
19
20 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
21 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
22
23 // Read inputs [input, begin, size]
24 // Gather input indices and use to get input tensor.
25 const int* inputTensors;
26 int numInputs;
27 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
28 {
29 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
30 tfLiteContext,
31 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
32 nodeIndex);
33 return kTfLiteError;
34 }
35
36 std::vector<const TfLiteOpaqueTensor*> tfLiteInputTensors;
37 tfLiteInputTensors.reserve(numInputs);
38 for (int i = 0; i < numInputs; i++)
39 {
40 const TfLiteOpaqueTensor* inputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[i]);
41 tfLiteInputTensors.push_back(inputTensor);
42 if (!IsValid(tfLiteContext, inputTensor, tfLiteSliceOperatorCode, nodeIndex))
43 {
44 return kTfLiteError;
45 }
46 }
47
48 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensors[0]);
49
50 // We save the begin and size tensors in our descriptor. Therefore we have to read those values from inputs
51 unsigned int inputRank = inputTensorInfo.GetNumDimensions();
52 auto ReadInt32Input = [&](int inputIndex, std::vector<uint32_t>& outputData) -> TfLiteStatus
53 {
54 if (TfLiteOpaqueTensorType(tfLiteInputTensors[inputIndex]) != kTfLiteInt32)
55 {
56 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
57 tfLiteContext,
58 "TfLiteArmnnOpaqueDelegate: The Begin- and Size-Tensors of the Slice operation need to "
59 "be of type int32. Operator: #%d node #%d: ",
60 tfLiteSliceOperatorCode, nodeIndex);
61 return kTfLiteError;
62 }
63 uint32_t rank = TfLiteOpaqueTensorNumDims(tfLiteInputTensors[inputIndex]);
64 if (rank != 1)
65 {
66 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
67 tfLiteContext,
68 "TfLiteArmnnOpaqueDelegate: The Begin- and Size-Tensors of the Slice operation need to "
69 "be a 1D-Tensor. Operator: #%d node #%d: ",
70 tfLiteSliceOperatorCode, nodeIndex);
71 return kTfLiteError;
72 }
73 uint32_t numValues = TfLiteOpaqueTensorDim(tfLiteInputTensors[inputIndex], 0);
74 if (numValues != inputRank)
75 {
76 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
77 tfLiteContext,
78 "TfLiteArmnnOpaqueDelegate: The number of values in the Begin- and Size-Tensors of the "
79 "Slice operation need to be equal to the rank of the Input-Tensor. Operator: #%d node #%d: ",
80 tfLiteSliceOperatorCode, nodeIndex);
81 return kTfLiteError;
82 }
83 // return tensor data
84 auto* tensorDataPtr = static_cast<uint32_t*>(TfLiteOpaqueTensorData(tfLiteInputTensors[inputIndex]));
85 outputData.assign(tensorDataPtr, tensorDataPtr + numValues);
86 return kTfLiteOk;
87 };
88
89 std::vector<uint32_t> begin;
90 if (ReadInt32Input(1, begin) != kTfLiteOk)
91 return kTfLiteError;
92 std::vector<uint32_t> size;
93 if (ReadInt32Input(2, size) != kTfLiteOk)
94 return kTfLiteError;
95
96 // Write all data to the descriptor
97 armnn::SliceDescriptor descriptor(begin, size);
98
99 // Validate output
100 // Gather output indices and use to get output tensor.
101 const int* outputTensors;
102 int numOutputs;
103 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
104 {
105 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
106 tfLiteContext,
107 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
108 nodeIndex);
109 return kTfLiteError;
110 }
111
112 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
113 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSliceOperatorCode, nodeIndex))
114 {
115 return kTfLiteError;
116 }
117
118 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
119
120 bool isSupported = false;
121 armnn::BackendId setBackend;
122 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
123 {
124 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("SLICE",
125 tfLiteContext,
126 IsSliceSupported,
127 delegateData.m_Backends,
128 isSupported,
129 setBackend,
130 inputTensorInfo,
131 outInfo,
132 descriptor);
133 };
134
135 if (!delegateData.m_Network)
136 {
137 validateFunc(outputTensorInfo, isSupported);
138 return isSupported ? kTfLiteOk : kTfLiteError;
139 }
140
141 // Add a Slice layer
142 armnn::IConnectableLayer* layer = delegateData.m_Network->AddSliceLayer(descriptor);
143 layer->SetBackendId(setBackend);
144 ARMNN_ASSERT(layer != nullptr);
145
146 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
147 outputSlot.SetTensorInfo(outputTensorInfo);
148
149 // try to connect the Constant Inputs if there are any
150 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
151 {
152 return kTfLiteError;
153 }
154
155 // Connect
156 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
157}

} // namespace armnnOpaqueDelegate