blob: 998e3d3e149f5414b5f7e499e388b6910ea09842 [file] [log] [blame]
Cathal Corbett839b9322022-11-18 08:52:18 +00001//
Ryan OShea4c231de2023-01-17 15:19:20 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Cathal Corbett839b9322022-11-18 08:52:18 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/utility/IgnoreUnused.hpp>
9
10#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus VisitStridedSliceOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
22 int32_t sliceOperatorCode)
23{
24 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 4, nodeIndex));
25 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26
27 // Read inputs [input, begin, end, strides]
28 int numInputs = tfLiteNode->inputs->size;
29 std::vector<const TfLiteTensor*> tfLiteInputs;
30 tfLiteInputs.reserve(numInputs);
31 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
32 for (int i = 0; i < numInputs; i++)
33 {
34 const TfLiteTensor* inputTensor = &tfLiteTensors[tfLiteNode->inputs->data[i]];
35 tfLiteInputs.push_back(inputTensor);
36 if (!IsValid(tfLiteContext, *inputTensor, sliceOperatorCode, nodeIndex))
37 {
38 return kTfLiteError;
39 }
40 }
41
42 // We save the begin, end and strides tensors in our descriptor. Therefore we have to read those values from inputs
43 int inputRank = tfLiteInputs[0]->dims->size;
44 auto ReadInt32Input = [&](int inputIndex, std::vector<int32_t>& outputData) -> TfLiteStatus
45 {
46 if (tfLiteInputs[inputIndex]->type != kTfLiteInt32)
47 {
48 TF_LITE_MAYBE_KERNEL_LOG(
49 tfLiteContext,
50 "TfLiteArmnnDelegate: The Begin-, End- and Stride-Tensors of the StridedSlice operation need to "
51 "be of type int32. Operator: #%d node #%d: ",
52 sliceOperatorCode, nodeIndex);
53 return kTfLiteError;
54 }
55 int rank = tfLiteInputs[inputIndex]->dims->size;
56 if (rank != 1)
57 {
58 TF_LITE_MAYBE_KERNEL_LOG(
59 tfLiteContext,
60 "TfLiteArmnnDelegate: The Begin-, End- and Stride-Tensors of the StridedSlice operation need to "
61 "be a 1D-Tensor. Operator: #%d node #%d: ",
62 sliceOperatorCode, nodeIndex);
63 return kTfLiteError;
64 }
65 int numValues = tfLiteInputs[inputIndex]->dims->data[0];
66 if (numValues != inputRank)
67 {
68 TF_LITE_MAYBE_KERNEL_LOG(
69 tfLiteContext,
70 "TfLiteArmnnDelegate: The number of values in the Begin-, End- and Stride-Tensors of the "
71 "StridedSlice operation need to be equal to the rank of the Input-Tensor. Operator: #%d node #%d: ",
72 sliceOperatorCode, nodeIndex);
73 return kTfLiteError;
74 }
75 // return tensor data
76 auto* tensorDataPtr = tflite::GetTensorData<int32_t>(tfLiteInputs[inputIndex]);
77 outputData.assign(tensorDataPtr, tensorDataPtr+numValues);
78 return kTfLiteOk;
79 };
80
81 std::vector<int32_t> beginData;
82 if (ReadInt32Input(1, beginData) != kTfLiteOk)
83 return kTfLiteError;
84 std::vector<int32_t> endData;
85 if (ReadInt32Input(2, endData) != kTfLiteOk)
86 return kTfLiteError;
87 std::vector<int32_t> strideData;
88 if (ReadInt32Input(3, strideData) != kTfLiteOk)
89 return kTfLiteError;
90
91 // parse built in options
92 auto* stridedSliceParams = reinterpret_cast<TfLiteStridedSliceParams*>(tfLiteNode->builtin_data);
93
94 // Write all data to the descriptor
95 armnn::StridedSliceDescriptor descriptor;
96 descriptor.m_Begin = std::move(beginData);
97 descriptor.m_End = std::move(endData);
98 descriptor.m_Stride = std::move(strideData);
99 descriptor.m_BeginMask = stridedSliceParams->begin_mask;
100 descriptor.m_EllipsisMask = stridedSliceParams->ellipsis_mask;
101 descriptor.m_EndMask = stridedSliceParams->end_mask;
102 descriptor.m_NewAxisMask = stridedSliceParams->new_axis_mask;
103 descriptor.m_ShrinkAxisMask = stridedSliceParams->shrink_axis_mask;
104 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
105
106 // Validate output
107 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
108 if (!IsValid(tfLiteContext, tfLiteOutputTensor, sliceOperatorCode, nodeIndex))
109 {
110 return kTfLiteError;
111 }
112
113 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(*tfLiteInputs[0]);
114 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
115
116 bool isSupported = false;
Cathal Corbett53837672022-09-01 11:34:37 +0100117 armnn::BackendId setBackend;
Cathal Corbett839b9322022-11-18 08:52:18 +0000118 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
119 {
120 FORWARD_LAYER_SUPPORT_FUNC("STRIDED_SLICE",
121 tfLiteContext,
122 IsStridedSliceSupported,
123 delegateData.m_Backends,
124 isSupported,
Cathal Corbett53837672022-09-01 11:34:37 +0100125 setBackend,
Cathal Corbett839b9322022-11-18 08:52:18 +0000126 inputTensorInfo,
127 outInfo,
128 descriptor);
129 };
130
131 if (!delegateData.m_Network)
132 {
133 validateFunc(outputTensorInfo, isSupported);
134 return isSupported ? kTfLiteOk : kTfLiteError;
135 }
136
137 // Add a StridedSlice layer
138 armnn::IConnectableLayer* layer = delegateData.m_Network->AddStridedSliceLayer(descriptor);
Cathal Corbett53837672022-09-01 11:34:37 +0100139 layer->SetBackendId(setBackend);
Cathal Corbett839b9322022-11-18 08:52:18 +0000140 ARMNN_ASSERT(layer != nullptr);
141
142 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
143 outputSlot.SetTensorInfo(outputTensorInfo);
144
Ryan OShea4c231de2023-01-17 15:19:20 +0000145 // try to connect the Constant Inputs if there are any
146 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
147 {
148 return kTfLiteError;
149 }
150
Cathal Corbett839b9322022-11-18 08:52:18 +0000151 // Connect
152 return Connect(layer, tfLiteNode, delegateData);
153}
154
155} // namespace armnnDelegate
Cathal Corbett53837672022-09-01 11:34:37 +0100156