//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <doctest/doctest.h>

namespace
{

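// Builds a FlatBuffer model containing a single STRIDED_SLICE operator, with the
// begin/end/stride values stored as constant INT32 tensors.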
std::vector<char> CreateStridedSliceTfLiteModel(tflite::TensorType tensorType,
                                                const std::vector<int32_t>& inputTensorShape,
                                                const std::vector<int32_t>& beginTensorData,
                                                const std::vector<int32_t>& endTensorData,
                                                const std::vector<int32_t>& strideTensorData,
                                                const std::vector<int32_t>& beginTensorShape,
                                                const std::vector<int32_t>& endTensorShape,
                                                const std::vector<int32_t>& strideTensorShape,
                                                const std::vector<int32_t>& outputTensorShape,
                                                const int32_t beginMask,
                                                const int32_t endMask,
                                                const int32_t ellipsisMask,
                                                const int32_t newAxisMask,
                                                const int32_t shrinkAxisMask,
                                                const armnn::DataLayout& dataLayout)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;

    std::array<flatbuffers::Offset<tflite::Buffer>, 6> buffers = {
        CreateBuffer(flatBufferBuilder),
        CreateBuffer(flatBufferBuilder),
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(beginTensorData.data()),
                                                    sizeof(int32_t) * beginTensorData.size())),
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(endTensorData.data()),
                                                    sizeof(int32_t) * endTensorData.size())),
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(strideTensorData.data()),
                                                    sizeof(int32_t) * strideTensorData.size())),
        CreateBuffer(flatBufferBuilder)
    };

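    // Five tensors: the input (0), the INT32 begin/end/stride parameter tensors (1-3)
    // and the output (4). Each references the matching buffer created above.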
    std::array<flatbuffers::Offset<Tensor>, 5> tensors;
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
                                                                      inputTensorShape.size()),
                              tensorType,
                              1,
                              flatBufferBuilder.CreateString("input"));
    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(beginTensorShape.data(),
                                                                      beginTensorShape.size()),
                              ::tflite::TensorType_INT32,
                              2,
                              flatBufferBuilder.CreateString("begin_tensor"));
    tensors[2] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(endTensorShape.data(),
                                                                      endTensorShape.size()),
                              ::tflite::TensorType_INT32,
                              3,
                              flatBufferBuilder.CreateString("end_tensor"));
    tensors[3] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(strideTensorShape.data(),
                                                                      strideTensorShape.size()),
                              ::tflite::TensorType_INT32,
                              4,
                              flatBufferBuilder.CreateString("stride_tensor"));
    tensors[4] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
                                                                      outputTensorShape.size()),
                              tensorType,
                              5,
                              flatBufferBuilder.CreateString("output"));

    // Create the operator and its StridedSlice builtin options.
    tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_StridedSliceOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions = CreateStridedSliceOptions(flatBufferBuilder,
                                                                                 beginMask,
                                                                                 endMask,
                                                                                 ellipsisMask,
                                                                                 newAxisMask,
                                                                                 shrinkAxisMask).Union();

    const std::vector<int> operatorInputs{ 0, 1, 2, 3 };
    const std::vector<int> operatorOutputs{ 4 };
    flatbuffers::Offset<Operator> sliceOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType,
                       operatorBuiltinOptions);

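    // Wrap the operator in a subgraph that exposes the same tensors as inputs and outputs.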
    const std::vector<int> subgraphInputs{ 0, 1, 2, 3 };
    const std::vector<int> subgraphOutputs{ 4 };
    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                       flatBufferBuilder.CreateVector(&sliceOperator, 1));

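    // Assemble the model from the operator code, subgraph, description and buffers.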
    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: StridedSlice Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
                                                                        BuiltinOperator_STRIDED_SLICE);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

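// Builds the StridedSlice model above, runs it through the reference TFLite runtime and
// through the Arm NN delegate, and checks that both produce the expected output values and shape.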
template <typename T>
void StridedSliceTestImpl(std::vector<armnn::BackendId>& backends,
                          std::vector<T>& inputValues,
                          std::vector<T>& expectedOutputValues,
                          std::vector<int32_t>& beginTensorData,
                          std::vector<int32_t>& endTensorData,
                          std::vector<int32_t>& strideTensorData,
                          std::vector<int32_t>& inputTensorShape,
                          std::vector<int32_t>& beginTensorShape,
                          std::vector<int32_t>& endTensorShape,
                          std::vector<int32_t>& strideTensorShape,
                          std::vector<int32_t>& outputTensorShape,
                          const int32_t beginMask = 0,
                          const int32_t endMask = 0,
                          const int32_t ellipsisMask = 0,
                          const int32_t newAxisMask = 0,
                          const int32_t shrinkAxisMask = 0,
                          const armnn::DataLayout& dataLayout = armnn::DataLayout::NHWC)
{
    using namespace delegateTestInterpreter;

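    // Build the FlatBuffer model once; the same buffer is fed to both interpreters below.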
    std::vector<char> modelBuffer = CreateStridedSliceTfLiteModel(
        ::tflite::TensorType_FLOAT32,
        inputTensorShape,
        beginTensorData,
        endTensorData,
        strideTensorData,
        beginTensorShape,
        endTensorShape,
        strideTensorShape,
        outputTensorShape,
        beginMask,
        endMask,
        ellipsisMask,
        newAxisMask,
        shrinkAxisMask,
        dataLayout);

    // Setup interpreter with just TFLite Runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);

    // Setup interpreter with Arm NN Delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);

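    // Both runs must match the expected values, and both must report the expected output shape.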
    armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShape);

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
} // End of StridedSlice Test

} // anonymous namespace