blob: a237034bb6c682b5837713a2e33eaa305a18bd78 [file] [log] [blame]
Sadik Armagan62483be2020-10-23 17:14:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Finn Williams6f9f9902020-11-13 13:23:15 +00008#include <armnn/utility/IgnoreUnused.hpp>
9
Sadik Armagan62483be2020-10-23 17:14:43 +010010#include <tensorflow/lite/builtin_ops.h>
11#include <tensorflow/lite/c/builtin_op_data.h>
12#include <tensorflow/lite/c/common.h>
13#include <tensorflow/lite/minimal_logging.h>
14
15namespace armnnDelegate
16{
17
18TfLiteStatus VisitSliceOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
22 int32_t sliceOperatorCode)
23{
Jan Eilers2ffddda2021-02-03 09:14:30 +000024 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 4, nodeIndex));
25 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
Finn Williams6f9f9902020-11-13 13:23:15 +000026
Jan Eilers2ffddda2021-02-03 09:14:30 +000027 // Read inputs [input, begin, end, strides]
28 int numInputs = tfLiteNode->inputs->size;
29 std::vector<const TfLiteTensor*> tfLiteInputs;
30 tfLiteInputs.reserve(numInputs);
31 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
32 for (int i = 0; i < numInputs; i++)
33 {
34 const TfLiteTensor* inputTensor = &tfLiteTensors[tfLiteNode->inputs->data[i]];
35 tfLiteInputs.push_back(inputTensor);
36 if (!IsValid(tfLiteContext, *inputTensor, sliceOperatorCode, nodeIndex))
37 {
38 return kTfLiteError;
39 }
40 }
41
42 // We save the begin, end and strides tensors in our descriptor. Therefore we have to read those values from inputs
43 int inputRank = tfLiteInputs[0]->dims->size;
44 auto ReadInt32Input = [&](int inputIndex, std::vector<int32_t>& outputData) -> TfLiteStatus
45 {
46 if (tfLiteInputs[inputIndex]->type != kTfLiteInt32)
47 {
48 TF_LITE_MAYBE_KERNEL_LOG(
49 tfLiteContext,
50 "TfLiteArmnnDelegate: The Begin-, End- and Stride-Tensors of the StridedSlice operation need to "
51 "be of type int32. Operator: #%d node #%d: ",
52 sliceOperatorCode, nodeIndex);
53 return kTfLiteError;
54 }
55 int rank = tfLiteInputs[inputIndex]->dims->size;
56 if (rank != 1)
57 {
58 TF_LITE_MAYBE_KERNEL_LOG(
59 tfLiteContext,
60 "TfLiteArmnnDelegate: The Begin-, End- and Stride-Tensors of the StridedSlice operation need to "
61 "be a 1D-Tensor. Operator: #%d node #%d: ",
62 sliceOperatorCode, nodeIndex);
63 return kTfLiteError;
64 }
65 int numValues = tfLiteInputs[inputIndex]->dims->data[0];
66 if (numValues != inputRank)
67 {
68 TF_LITE_MAYBE_KERNEL_LOG(
69 tfLiteContext,
70 "TfLiteArmnnDelegate: The number of values in the Begin-, End- and Stride-Tensors of the "
71 "StridedSlice operation need to be equal to the rank of the Input-Tensor. Operator: #%d node #%d: ",
72 sliceOperatorCode, nodeIndex);
73 return kTfLiteError;
74 }
75 // return tensor data
76 auto* tensorDataPtr = tflite::GetTensorData<int32_t>(tfLiteInputs[inputIndex]);
77 outputData.assign(tensorDataPtr, tensorDataPtr+numValues);
78 return kTfLiteOk;
79 };
80
81 std::vector<int32_t> beginData;
82 if (ReadInt32Input(1, beginData) != kTfLiteOk)
83 return kTfLiteError;
84 std::vector<int32_t> endData;
85 if (ReadInt32Input(2, endData) != kTfLiteOk)
86 return kTfLiteError;
87 std::vector<int32_t> strideData;
88 if (ReadInt32Input(3, strideData) != kTfLiteOk)
89 return kTfLiteError;
90
91 // parse built in options
92 auto* stridedSliceParams = reinterpret_cast<TfLiteStridedSliceParams*>(tfLiteNode->builtin_data);
93
94 // Write all data to the descriptor
95 armnn::StridedSliceDescriptor descriptor;
96 descriptor.m_Begin = std::move(beginData);
97 descriptor.m_End = std::move(endData);
98 descriptor.m_Stride = std::move(strideData);
99 descriptor.m_BeginMask = stridedSliceParams->begin_mask;
100 descriptor.m_EllipsisMask = stridedSliceParams->ellipsis_mask;
101 descriptor.m_EndMask = stridedSliceParams->end_mask;
102 descriptor.m_NewAxisMask = stridedSliceParams->new_axis_mask;
103 descriptor.m_ShrinkAxisMask = stridedSliceParams->shrink_axis_mask;
104 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
105
106 // Validate output
107 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
108 if (!IsValid(tfLiteContext, tfLiteOutputTensor, sliceOperatorCode, nodeIndex))
109 {
110 return kTfLiteError;
111 }
112
113 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(*tfLiteInputs[0]);
114 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
115
116 bool isSupported = false;
117 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
118 {
119 FORWARD_LAYER_SUPPORT_FUNC(__func__,
120 tfLiteContext,
121 IsStridedSliceSupported,
122 delegateData.m_Backends,
123 isSupported,
124 inputTensorInfo,
125 outInfo,
126 descriptor);
127 };
128
129 if (!delegateData.m_Network)
130 {
131 validateFunc(outputTensorInfo, isSupported);
132 return isSupported ? kTfLiteOk : kTfLiteError;
133 }
134
135 // Add a StridedSlice layer
136 armnn::IConnectableLayer* layer = delegateData.m_Network->AddStridedSliceLayer(descriptor);
137 ARMNN_ASSERT(layer != nullptr);
138
139 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
140 outputSlot.SetTensorInfo(outputTensorInfo);
141
142 // Connect
143 return Connect(layer, tfLiteNode, delegateData);
Sadik Armagan62483be2020-10-23 17:14:43 +0100144}
145
146} // namespace armnnDelegate