blob: 2e17e3292f82a881eb09e817fc8a653a11ffd2d8 [file] [log] [blame]
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
Teresa Charlin86b03572023-04-28 13:19:12 +01005
6#pragma once
7
8#include <OpaqueDelegateUtils.hpp>
9
10namespace armnnOpaqueDelegate
11{
12
13TfLiteStatus VisitStridedSliceOperator(DelegateData& delegateData,
14 TfLiteOpaqueContext* tfLiteContext,
15 TfLiteOpaqueNode* tfLiteNode,
16 int nodeIndex,
17 int32_t tfLiteStridedSliceOperatorCode)
18{
19 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 4, nodeIndex));
20 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
21
22 // Read inputs [input, begin, end, strides]
23 // Gather input indices and use to get input tensor.
24 const int* inputTensors;
25 int numInputs;
26 if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
27 {
28 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
29 tfLiteContext,
30 "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
31 nodeIndex);
32 return kTfLiteError;
33 }
34
35 std::vector<const TfLiteOpaqueTensor*> tfLiteInputTensors;
36 tfLiteInputTensors.reserve(numInputs);
37 for (int i = 0; i < numInputs; i++)
38 {
39 const TfLiteOpaqueTensor* inputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[i]);
40 tfLiteInputTensors.push_back(inputTensor);
41 if (!IsValid(tfLiteContext, inputTensor, tfLiteStridedSliceOperatorCode, nodeIndex))
42 {
43 return kTfLiteError;
44 }
45 }
46
47 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensors[0]);
48
49 // We save the begin, end and strides tensors in our descriptor. Therefore we have to read those values from inputs
50 unsigned int inputRank = inputTensorInfo.GetNumDimensions();
51 auto ReadInt32Input = [&](int inputIndex, std::vector<int32_t>& outputData) -> TfLiteStatus
52 {
53 if (TfLiteOpaqueTensorType(tfLiteInputTensors[inputIndex]) != kTfLiteInt32)
54 {
55 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
56 tfLiteContext,
57 "TfLitearmnnOpaqueDelegate: The Begin-, End- and Stride-Tensors of the StridedSlice operation need"
58 " to be of type int32. Operator: #%d node #%d: ",
59 tfLiteStridedSliceOperatorCode, nodeIndex);
60 return kTfLiteError;
61 }
62 uint32_t rank = TfLiteOpaqueTensorNumDims(tfLiteInputTensors[inputIndex]);
63 if (rank != 1)
64 {
65 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
66 tfLiteContext,
67 "TfLitearmnnOpaqueDelegate: The Begin-, End- and Stride-Tensors of the StridedSlice operation need"
68 " to be a 1D-Tensor. Operator: #%d node #%d: ",
69 tfLiteStridedSliceOperatorCode, nodeIndex);
70 return kTfLiteError;
71 }
72 uint32_t numValues = TfLiteOpaqueTensorDim(tfLiteInputTensors[inputIndex], 0);
73 if (numValues != inputRank)
74 {
75 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
76 tfLiteContext,
77 "TfLitearmnnOpaqueDelegate: The number of values in the Begin-, End- and Stride-Tensors of the "
78 "StridedSlice operation need to be equal to the rank of the Input-Tensor. Operator: #%d node #%d: ",
79 tfLiteStridedSliceOperatorCode, nodeIndex);
80 return kTfLiteError;
81 }
82 // return tensor data
83 auto* tensorDataPtr = static_cast<uint32_t*>(TfLiteOpaqueTensorData(tfLiteInputTensors[inputIndex]));
84 outputData.assign(tensorDataPtr, tensorDataPtr + numValues);
85 return kTfLiteOk;
86 };
87
88 std::vector<int32_t> beginData;
89 if (ReadInt32Input(1, beginData) != kTfLiteOk)
90 return kTfLiteError;
91 std::vector<int32_t> endData;
92 if (ReadInt32Input(2, endData) != kTfLiteOk)
93 return kTfLiteError;
94 std::vector<int32_t> strideData;
95 if (ReadInt32Input(3, strideData) != kTfLiteOk)
96 return kTfLiteError;
97
98 // parse built in options
99 auto* nodeParameters = reinterpret_cast<TfLiteStridedSliceParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
100
101 // Write all data to the descriptor
102 armnn::StridedSliceDescriptor descriptor;
103 descriptor.m_Begin = std::move(beginData);
104 descriptor.m_End = std::move(endData);
105 descriptor.m_Stride = std::move(strideData);
106 descriptor.m_BeginMask = nodeParameters->begin_mask;
107 descriptor.m_EllipsisMask = nodeParameters->ellipsis_mask;
108 descriptor.m_EndMask = nodeParameters->end_mask;
109 descriptor.m_NewAxisMask = nodeParameters->new_axis_mask;
110 descriptor.m_ShrinkAxisMask = nodeParameters->shrink_axis_mask;
111 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
112
113 // Validate output
114 // Gather output indices and use to get output tensor.
115 const int* outputTensors;
116 int numOutputs;
117 if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
118 {
119 TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
120 tfLiteContext,
121 "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
122 nodeIndex);
123 return kTfLiteError;
124 }
125
126 const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
127 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteStridedSliceOperatorCode, nodeIndex))
128 {
129 return kTfLiteError;
130 }
131
132 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor);
133
134 bool isSupported = false;
135 armnn::BackendId setBackend;
136 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
137 {
138 FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("STRIDED_SLICE",
139 tfLiteContext,
140 IsStridedSliceSupported,
141 delegateData.m_Backends,
142 isSupported,
143 setBackend,
144 inputTensorInfo,
145 outInfo,
146 descriptor);
147 };
148
149 if (!delegateData.m_Network)
150 {
151 validateFunc(outputTensorInfo, isSupported);
152 return isSupported ? kTfLiteOk : kTfLiteError;
153 }
154
155 // Add a StridedSlice layer
Mike Kellya2806502023-08-03 10:42:11 +0100156 auto layerName = GetName(armnn::LayerType::StridedSlice, nodeIndex);
157 armnn::IConnectableLayer* layer = delegateData.m_Network->AddStridedSliceLayer(descriptor, layerName.c_str());
Teresa Charlin86b03572023-04-28 13:19:12 +0100158 layer->SetBackendId(setBackend);
159 ARMNN_ASSERT(layer != nullptr);
160
161 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
162 outputSlot.SetTensorInfo(outputTensorInfo);
163
164 // try to connect the Constant Inputs if there are any
Mike Kellya2806502023-08-03 10:42:11 +0100165 if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode, nodeIndex) != kTfLiteOk)
Teresa Charlin86b03572023-04-28 13:19:12 +0100166 {
167 return kTfLiteError;
168 }
169
170 // Connect
171 return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
172}
173
174} // namespace armnnOpaqueDelegate
175