//
// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010011#include <DelegateTestInterpreter.hpp>
Cathal Corbett839b9322022-11-18 08:52:18 +000012
Cathal Corbett839b9322022-11-18 08:52:18 +000013#include <tensorflow/lite/version.h>
14
Cathal Corbett839b9322022-11-18 08:52:18 +000015namespace
16{
17
18std::vector<char> CreateStridedSliceTfLiteModel(tflite::TensorType tensorType,
19 const std::vector<int32_t>& inputTensorShape,
20 const std::vector<int32_t>& beginTensorData,
21 const std::vector<int32_t>& endTensorData,
22 const std::vector<int32_t>& strideTensorData,
23 const std::vector<int32_t>& beginTensorShape,
24 const std::vector<int32_t>& endTensorShape,
25 const std::vector<int32_t>& strideTensorShape,
26 const std::vector<int32_t>& outputTensorShape,
27 const int32_t beginMask,
28 const int32_t endMask,
29 const int32_t ellipsisMask,
30 const int32_t newAxisMask,
31 const int32_t ShrinkAxisMask,
32 const armnn::DataLayout& dataLayout)
33{
34 using namespace tflite;
35 flatbuffers::FlatBufferBuilder flatBufferBuilder;
36
Ryan OShea238ecd92023-03-07 11:44:23 +000037 flatbuffers::Offset<tflite::Buffer> buffers[6] = {
38 CreateBuffer(flatBufferBuilder),
39 CreateBuffer(flatBufferBuilder),
40 CreateBuffer(flatBufferBuilder,
41 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(beginTensorData.data()),
42 sizeof(int32_t) * beginTensorData.size())),
43 CreateBuffer(flatBufferBuilder,
44 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(endTensorData.data()),
45 sizeof(int32_t) * endTensorData.size())),
46 CreateBuffer(flatBufferBuilder,
47 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(strideTensorData.data()),
48 sizeof(int32_t) * strideTensorData.size())),
49 CreateBuffer(flatBufferBuilder)
50 };
Cathal Corbett839b9322022-11-18 08:52:18 +000051
52 std::array<flatbuffers::Offset<Tensor>, 5> tensors;
53 tensors[0] = CreateTensor(flatBufferBuilder,
54 flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
55 inputTensorShape.size()),
56 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000057 1,
Cathal Corbett839b9322022-11-18 08:52:18 +000058 flatBufferBuilder.CreateString("input"));
59 tensors[1] = CreateTensor(flatBufferBuilder,
60 flatBufferBuilder.CreateVector<int32_t>(beginTensorShape.data(),
61 beginTensorShape.size()),
62 ::tflite::TensorType_INT32,
Ryan OShea238ecd92023-03-07 11:44:23 +000063 2,
Cathal Corbett839b9322022-11-18 08:52:18 +000064 flatBufferBuilder.CreateString("begin_tensor"));
65 tensors[2] = CreateTensor(flatBufferBuilder,
66 flatBufferBuilder.CreateVector<int32_t>(endTensorShape.data(),
67 endTensorShape.size()),
68 ::tflite::TensorType_INT32,
Ryan OShea238ecd92023-03-07 11:44:23 +000069 3,
Cathal Corbett839b9322022-11-18 08:52:18 +000070 flatBufferBuilder.CreateString("end_tensor"));
71 tensors[3] = CreateTensor(flatBufferBuilder,
72 flatBufferBuilder.CreateVector<int32_t>(strideTensorShape.data(),
73 strideTensorShape.size()),
74 ::tflite::TensorType_INT32,
Ryan OShea238ecd92023-03-07 11:44:23 +000075 4,
Cathal Corbett839b9322022-11-18 08:52:18 +000076 flatBufferBuilder.CreateString("stride_tensor"));
77 tensors[4] = CreateTensor(flatBufferBuilder,
78 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
79 outputTensorShape.size()),
80 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000081 5,
Cathal Corbett839b9322022-11-18 08:52:18 +000082 flatBufferBuilder.CreateString("output"));
83
84
85 // create operator
86 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_StridedSliceOptions;
87 flatbuffers::Offset<void> operatorBuiltinOptions = CreateStridedSliceOptions(flatBufferBuilder,
88 beginMask,
89 endMask,
90 ellipsisMask,
91 newAxisMask,
92 ShrinkAxisMask).Union();
93
94 const std::vector<int> operatorInputs{ 0, 1, 2, 3 };
95 const std::vector<int> operatorOutputs{ 4 };
96 flatbuffers::Offset <Operator> sliceOperator =
97 CreateOperator(flatBufferBuilder,
98 0,
99 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
100 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
101 operatorBuiltinOptionsType,
102 operatorBuiltinOptions);
103
104 const std::vector<int> subgraphInputs{ 0, 1, 2, 3 };
105 const std::vector<int> subgraphOutputs{ 4 };
106 flatbuffers::Offset <SubGraph> subgraph =
107 CreateSubGraph(flatBufferBuilder,
108 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
109 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
110 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
111 flatBufferBuilder.CreateVector(&sliceOperator, 1));
112
113 flatbuffers::Offset <flatbuffers::String> modelDescription =
114 flatBufferBuilder.CreateString("ArmnnDelegate: StridedSlice Operator Model");
115 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
116 BuiltinOperator_STRIDED_SLICE);
117
118 flatbuffers::Offset <Model> flatbufferModel =
119 CreateModel(flatBufferBuilder,
120 TFLITE_SCHEMA_VERSION,
121 flatBufferBuilder.CreateVector(&operatorCode, 1),
122 flatBufferBuilder.CreateVector(&subgraph, 1),
123 modelDescription,
Ryan OShea238ecd92023-03-07 11:44:23 +0000124 flatBufferBuilder.CreateVector(buffers, 6));
Cathal Corbett839b9322022-11-18 08:52:18 +0000125
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100126 flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
Cathal Corbett839b9322022-11-18 08:52:18 +0000127
128 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
129 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
130}
131
132template <typename T>
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000133void StridedSliceTestImpl(std::vector<T>& inputValues,
Cathal Corbett839b9322022-11-18 08:52:18 +0000134 std::vector<T>& expectedOutputValues,
135 std::vector<int32_t>& beginTensorData,
136 std::vector<int32_t>& endTensorData,
137 std::vector<int32_t>& strideTensorData,
138 std::vector<int32_t>& inputTensorShape,
139 std::vector<int32_t>& beginTensorShape,
140 std::vector<int32_t>& endTensorShape,
141 std::vector<int32_t>& strideTensorShape,
142 std::vector<int32_t>& outputTensorShape,
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000143 const std::vector<armnn::BackendId>& backends = {},
Cathal Corbett839b9322022-11-18 08:52:18 +0000144 const int32_t beginMask = 0,
145 const int32_t endMask = 0,
146 const int32_t ellipsisMask = 0,
147 const int32_t newAxisMask = 0,
148 const int32_t ShrinkAxisMask = 0,
149 const armnn::DataLayout& dataLayout = armnn::DataLayout::NHWC)
150{
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100151 using namespace delegateTestInterpreter;
Cathal Corbett839b9322022-11-18 08:52:18 +0000152 std::vector<char> modelBuffer = CreateStridedSliceTfLiteModel(
153 ::tflite::TensorType_FLOAT32,
154 inputTensorShape,
155 beginTensorData,
156 endTensorData,
157 strideTensorData,
158 beginTensorShape,
159 endTensorShape,
160 strideTensorShape,
161 outputTensorShape,
162 beginMask,
163 endMask,
164 ellipsisMask,
165 newAxisMask,
166 ShrinkAxisMask,
167 dataLayout);
168
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100169 // Setup interpreter with just TFLite Runtime.
170 auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
171 CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
172 CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
173 CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
174 std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
175 std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
Cathal Corbett839b9322022-11-18 08:52:18 +0000176
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100177 // Setup interpreter with Arm NN Delegate applied.
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000178 auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, CaptureAvailableBackends(backends));
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100179 CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
180 CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
181 CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
182 std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
183 std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
Cathal Corbett839b9322022-11-18 08:52:18 +0000184
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100185 armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
186 armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShape);
Cathal Corbett839b9322022-11-18 08:52:18 +0000187
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100188 tfLiteInterpreter.Cleanup();
189 armnnInterpreter.Cleanup();
Cathal Corbett839b9322022-11-18 08:52:18 +0000190} // End of StridedSlice Test
191
192} // anonymous namespace