blob: d3d160158bb8cea379b8f8019b1f8fb8dd898ed4 [file] [log] [blame]
Cathal Corbett839b9322022-11-18 08:52:18 +00001//
Ryan OShea238ecd92023-03-07 11:44:23 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Cathal Corbett839b9322022-11-18 08:52:18 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010011#include <DelegateTestInterpreter.hpp>
Cathal Corbett839b9322022-11-18 08:52:18 +000012
13#include <flatbuffers/flatbuffers.h>
Cathal Corbett839b9322022-11-18 08:52:18 +000014#include <tensorflow/lite/kernels/register.h>
Cathal Corbett839b9322022-11-18 08:52:18 +000015#include <tensorflow/lite/version.h>
16
Matthew Sloyanebe392d2023-03-30 10:12:08 +010017#include <schema_generated.h>
Cathal Corbett839b9322022-11-18 08:52:18 +000018
Matthew Sloyanebe392d2023-03-30 10:12:08 +010019#include <doctest/doctest.h>
Cathal Corbett839b9322022-11-18 08:52:18 +000020
21namespace
22{
23
24std::vector<char> CreateStridedSliceTfLiteModel(tflite::TensorType tensorType,
25 const std::vector<int32_t>& inputTensorShape,
26 const std::vector<int32_t>& beginTensorData,
27 const std::vector<int32_t>& endTensorData,
28 const std::vector<int32_t>& strideTensorData,
29 const std::vector<int32_t>& beginTensorShape,
30 const std::vector<int32_t>& endTensorShape,
31 const std::vector<int32_t>& strideTensorShape,
32 const std::vector<int32_t>& outputTensorShape,
33 const int32_t beginMask,
34 const int32_t endMask,
35 const int32_t ellipsisMask,
36 const int32_t newAxisMask,
37 const int32_t ShrinkAxisMask,
38 const armnn::DataLayout& dataLayout)
39{
40 using namespace tflite;
41 flatbuffers::FlatBufferBuilder flatBufferBuilder;
42
Ryan OShea238ecd92023-03-07 11:44:23 +000043 flatbuffers::Offset<tflite::Buffer> buffers[6] = {
44 CreateBuffer(flatBufferBuilder),
45 CreateBuffer(flatBufferBuilder),
46 CreateBuffer(flatBufferBuilder,
47 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(beginTensorData.data()),
48 sizeof(int32_t) * beginTensorData.size())),
49 CreateBuffer(flatBufferBuilder,
50 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(endTensorData.data()),
51 sizeof(int32_t) * endTensorData.size())),
52 CreateBuffer(flatBufferBuilder,
53 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(strideTensorData.data()),
54 sizeof(int32_t) * strideTensorData.size())),
55 CreateBuffer(flatBufferBuilder)
56 };
Cathal Corbett839b9322022-11-18 08:52:18 +000057
58 std::array<flatbuffers::Offset<Tensor>, 5> tensors;
59 tensors[0] = CreateTensor(flatBufferBuilder,
60 flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
61 inputTensorShape.size()),
62 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000063 1,
Cathal Corbett839b9322022-11-18 08:52:18 +000064 flatBufferBuilder.CreateString("input"));
65 tensors[1] = CreateTensor(flatBufferBuilder,
66 flatBufferBuilder.CreateVector<int32_t>(beginTensorShape.data(),
67 beginTensorShape.size()),
68 ::tflite::TensorType_INT32,
Ryan OShea238ecd92023-03-07 11:44:23 +000069 2,
Cathal Corbett839b9322022-11-18 08:52:18 +000070 flatBufferBuilder.CreateString("begin_tensor"));
71 tensors[2] = CreateTensor(flatBufferBuilder,
72 flatBufferBuilder.CreateVector<int32_t>(endTensorShape.data(),
73 endTensorShape.size()),
74 ::tflite::TensorType_INT32,
Ryan OShea238ecd92023-03-07 11:44:23 +000075 3,
Cathal Corbett839b9322022-11-18 08:52:18 +000076 flatBufferBuilder.CreateString("end_tensor"));
77 tensors[3] = CreateTensor(flatBufferBuilder,
78 flatBufferBuilder.CreateVector<int32_t>(strideTensorShape.data(),
79 strideTensorShape.size()),
80 ::tflite::TensorType_INT32,
Ryan OShea238ecd92023-03-07 11:44:23 +000081 4,
Cathal Corbett839b9322022-11-18 08:52:18 +000082 flatBufferBuilder.CreateString("stride_tensor"));
83 tensors[4] = CreateTensor(flatBufferBuilder,
84 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
85 outputTensorShape.size()),
86 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000087 5,
Cathal Corbett839b9322022-11-18 08:52:18 +000088 flatBufferBuilder.CreateString("output"));
89
90
91 // create operator
92 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_StridedSliceOptions;
93 flatbuffers::Offset<void> operatorBuiltinOptions = CreateStridedSliceOptions(flatBufferBuilder,
94 beginMask,
95 endMask,
96 ellipsisMask,
97 newAxisMask,
98 ShrinkAxisMask).Union();
99
100 const std::vector<int> operatorInputs{ 0, 1, 2, 3 };
101 const std::vector<int> operatorOutputs{ 4 };
102 flatbuffers::Offset <Operator> sliceOperator =
103 CreateOperator(flatBufferBuilder,
104 0,
105 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
106 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
107 operatorBuiltinOptionsType,
108 operatorBuiltinOptions);
109
110 const std::vector<int> subgraphInputs{ 0, 1, 2, 3 };
111 const std::vector<int> subgraphOutputs{ 4 };
112 flatbuffers::Offset <SubGraph> subgraph =
113 CreateSubGraph(flatBufferBuilder,
114 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
115 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
116 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
117 flatBufferBuilder.CreateVector(&sliceOperator, 1));
118
119 flatbuffers::Offset <flatbuffers::String> modelDescription =
120 flatBufferBuilder.CreateString("ArmnnDelegate: StridedSlice Operator Model");
121 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
122 BuiltinOperator_STRIDED_SLICE);
123
124 flatbuffers::Offset <Model> flatbufferModel =
125 CreateModel(flatBufferBuilder,
126 TFLITE_SCHEMA_VERSION,
127 flatBufferBuilder.CreateVector(&operatorCode, 1),
128 flatBufferBuilder.CreateVector(&subgraph, 1),
129 modelDescription,
Ryan OShea238ecd92023-03-07 11:44:23 +0000130 flatBufferBuilder.CreateVector(buffers, 6));
Cathal Corbett839b9322022-11-18 08:52:18 +0000131
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100132 flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
Cathal Corbett839b9322022-11-18 08:52:18 +0000133
134 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
135 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
136}
137
138template <typename T>
139void StridedSliceTestImpl(std::vector<armnn::BackendId>& backends,
140 std::vector<T>& inputValues,
141 std::vector<T>& expectedOutputValues,
142 std::vector<int32_t>& beginTensorData,
143 std::vector<int32_t>& endTensorData,
144 std::vector<int32_t>& strideTensorData,
145 std::vector<int32_t>& inputTensorShape,
146 std::vector<int32_t>& beginTensorShape,
147 std::vector<int32_t>& endTensorShape,
148 std::vector<int32_t>& strideTensorShape,
149 std::vector<int32_t>& outputTensorShape,
150 const int32_t beginMask = 0,
151 const int32_t endMask = 0,
152 const int32_t ellipsisMask = 0,
153 const int32_t newAxisMask = 0,
154 const int32_t ShrinkAxisMask = 0,
155 const armnn::DataLayout& dataLayout = armnn::DataLayout::NHWC)
156{
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100157 using namespace delegateTestInterpreter;
Cathal Corbett839b9322022-11-18 08:52:18 +0000158 std::vector<char> modelBuffer = CreateStridedSliceTfLiteModel(
159 ::tflite::TensorType_FLOAT32,
160 inputTensorShape,
161 beginTensorData,
162 endTensorData,
163 strideTensorData,
164 beginTensorShape,
165 endTensorShape,
166 strideTensorShape,
167 outputTensorShape,
168 beginMask,
169 endMask,
170 ellipsisMask,
171 newAxisMask,
172 ShrinkAxisMask,
173 dataLayout);
174
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100175 // Setup interpreter with just TFLite Runtime.
176 auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
177 CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
178 CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
179 CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
180 std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
181 std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
Cathal Corbett839b9322022-11-18 08:52:18 +0000182
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100183 // Setup interpreter with Arm NN Delegate applied.
184 auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
185 CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
186 CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
187 CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
188 std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
189 std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
Cathal Corbett839b9322022-11-18 08:52:18 +0000190
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100191 armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
192 armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShape);
Cathal Corbett839b9322022-11-18 08:52:18 +0000193
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100194 tfLiteInterpreter.Cleanup();
195 armnnInterpreter.Cleanup();
Cathal Corbett839b9322022-11-18 08:52:18 +0000196} // End of StridedSlice Test
197
198} // anonymous namespace