blob: c938fad31bace78e183893cd34c2e625568d80c8 [file] [log] [blame]
Jan Eilers2ffddda2021-02-03 09:14:30 +00001//
Ryan OShea238ecd92023-03-07 11:44:23 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Jan Eilers2ffddda2021-02-03 09:14:30 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
11#include <armnn/DescriptorsFwd.hpp>
12
13#include <flatbuffers/flatbuffers.h>
14#include <tensorflow/lite/interpreter.h>
15#include <tensorflow/lite/kernels/register.h>
16#include <tensorflow/lite/model.h>
Teresa Charlinad1b3d72023-03-14 12:10:28 +000017#include <schema_generated.h>
Jan Eilers2ffddda2021-02-03 09:14:30 +000018#include <tensorflow/lite/version.h>
19
20#include <doctest/doctest.h>
21
22#include <string>
23
24namespace
25{
26
Jan Eilers2ffddda2021-02-03 09:14:30 +000027std::vector<char> CreateSliceTfLiteModel(tflite::TensorType tensorType,
28 const std::vector<int32_t>& inputTensorShape,
29 const std::vector<int32_t>& beginTensorData,
Cathal Corbett839b9322022-11-18 08:52:18 +000030 const std::vector<int32_t>& sizeTensorData,
Jan Eilers2ffddda2021-02-03 09:14:30 +000031 const std::vector<int32_t>& beginTensorShape,
Cathal Corbett839b9322022-11-18 08:52:18 +000032 const std::vector<int32_t>& sizeTensorShape,
33 const std::vector<int32_t>& outputTensorShape)
Jan Eilers2ffddda2021-02-03 09:14:30 +000034{
35 using namespace tflite;
36 flatbuffers::FlatBufferBuilder flatBufferBuilder;
37
Ryan OShea238ecd92023-03-07 11:44:23 +000038 flatbuffers::Offset<tflite::Buffer> buffers[5] = {
39 CreateBuffer(flatBufferBuilder),
40 CreateBuffer(flatBufferBuilder),
41 CreateBuffer(flatBufferBuilder,
42 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(beginTensorData.data()),
43 sizeof(int32_t) * beginTensorData.size())),
44 CreateBuffer(flatBufferBuilder,
45 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(sizeTensorData.data()),
46 sizeof(int32_t) * sizeTensorData.size())),
47 CreateBuffer(flatBufferBuilder)
48 };
Jan Eilers2ffddda2021-02-03 09:14:30 +000049
Cathal Corbett839b9322022-11-18 08:52:18 +000050 std::array<flatbuffers::Offset<Tensor>, 4> tensors;
Jan Eilers2ffddda2021-02-03 09:14:30 +000051 tensors[0] = CreateTensor(flatBufferBuilder,
52 flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
53 inputTensorShape.size()),
54 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000055 1,
Jan Eilers2ffddda2021-02-03 09:14:30 +000056 flatBufferBuilder.CreateString("input"));
57 tensors[1] = CreateTensor(flatBufferBuilder,
58 flatBufferBuilder.CreateVector<int32_t>(beginTensorShape.data(),
59 beginTensorShape.size()),
60 ::tflite::TensorType_INT32,
Ryan OShea238ecd92023-03-07 11:44:23 +000061 2,
Jan Eilers2ffddda2021-02-03 09:14:30 +000062 flatBufferBuilder.CreateString("begin_tensor"));
63 tensors[2] = CreateTensor(flatBufferBuilder,
Cathal Corbett839b9322022-11-18 08:52:18 +000064 flatBufferBuilder.CreateVector<int32_t>(sizeTensorShape.data(),
65 sizeTensorShape.size()),
Jan Eilers2ffddda2021-02-03 09:14:30 +000066 ::tflite::TensorType_INT32,
Ryan OShea238ecd92023-03-07 11:44:23 +000067 3,
Cathal Corbett839b9322022-11-18 08:52:18 +000068 flatBufferBuilder.CreateString("size_tensor"));
Jan Eilers2ffddda2021-02-03 09:14:30 +000069 tensors[3] = CreateTensor(flatBufferBuilder,
Jan Eilers2ffddda2021-02-03 09:14:30 +000070 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
71 outputTensorShape.size()),
72 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000073 4,
Jan Eilers2ffddda2021-02-03 09:14:30 +000074 flatBufferBuilder.CreateString("output"));
75
76
77 // create operator
Cathal Corbett839b9322022-11-18 08:52:18 +000078 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_SliceOptions;
79 flatbuffers::Offset<void> operatorBuiltinOptions = CreateSliceOptions(flatBufferBuilder).Union();
Jan Eilers2ffddda2021-02-03 09:14:30 +000080
Cathal Corbett839b9322022-11-18 08:52:18 +000081 const std::vector<int> operatorInputs{ 0, 1, 2 };
82 const std::vector<int> operatorOutputs{ 3 };
Jan Eilers2ffddda2021-02-03 09:14:30 +000083 flatbuffers::Offset <Operator> sliceOperator =
Cathal Corbett839b9322022-11-18 08:52:18 +000084 CreateOperator(flatBufferBuilder,
85 0,
86 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
87 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
88 operatorBuiltinOptionsType,
89 operatorBuiltinOptions);
Jan Eilers2ffddda2021-02-03 09:14:30 +000090
Cathal Corbett839b9322022-11-18 08:52:18 +000091 const std::vector<int> subgraphInputs{ 0, 1, 2 };
92 const std::vector<int> subgraphOutputs{ 3 };
Jan Eilers2ffddda2021-02-03 09:14:30 +000093 flatbuffers::Offset <SubGraph> subgraph =
Cathal Corbett839b9322022-11-18 08:52:18 +000094 CreateSubGraph(flatBufferBuilder,
95 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
96 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
97 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
98 flatBufferBuilder.CreateVector(&sliceOperator, 1));
Jan Eilers2ffddda2021-02-03 09:14:30 +000099
100 flatbuffers::Offset <flatbuffers::String> modelDescription =
Cathal Corbett839b9322022-11-18 08:52:18 +0000101 flatBufferBuilder.CreateString("ArmnnDelegate: Slice Operator Model");
Jan Eilers2ffddda2021-02-03 09:14:30 +0000102 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
Cathal Corbett839b9322022-11-18 08:52:18 +0000103 BuiltinOperator_SLICE);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000104
105 flatbuffers::Offset <Model> flatbufferModel =
Cathal Corbett839b9322022-11-18 08:52:18 +0000106 CreateModel(flatBufferBuilder,
107 TFLITE_SCHEMA_VERSION,
108 flatBufferBuilder.CreateVector(&operatorCode, 1),
109 flatBufferBuilder.CreateVector(&subgraph, 1),
110 modelDescription,
Ryan OShea238ecd92023-03-07 11:44:23 +0000111 flatBufferBuilder.CreateVector(buffers, 5));
Jan Eilers2ffddda2021-02-03 09:14:30 +0000112
113 flatBufferBuilder.Finish(flatbufferModel);
114
115 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
116 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
117}
118
119template <typename T>
Cathal Corbett839b9322022-11-18 08:52:18 +0000120void SliceTestImpl(std::vector<armnn::BackendId>& backends,
121 std::vector<T>& inputValues,
122 std::vector<T>& expectedOutputValues,
123 std::vector<int32_t>& beginTensorData,
124 std::vector<int32_t>& sizeTensorData,
125 std::vector<int32_t>& inputTensorShape,
126 std::vector<int32_t>& beginTensorShape,
127 std::vector<int32_t>& sizeTensorShape,
128 std::vector<int32_t>& outputTensorShape)
Jan Eilers2ffddda2021-02-03 09:14:30 +0000129{
130 using namespace tflite;
131 std::vector<char> modelBuffer = CreateSliceTfLiteModel(
Cathal Corbett839b9322022-11-18 08:52:18 +0000132 ::tflite::TensorType_FLOAT32,
133 inputTensorShape,
134 beginTensorData,
135 sizeTensorData,
136 beginTensorShape,
137 sizeTensorShape,
138 outputTensorShape);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000139
140 auto tfLiteModel = GetModel(modelBuffer.data());
141
142 // Create TfLite Interpreters
143 std::unique_ptr<Interpreter> armnnDelegate;
144 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
Cathal Corbett839b9322022-11-18 08:52:18 +0000145 (&armnnDelegate) == kTfLiteOk);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000146 CHECK(armnnDelegate != nullptr);
147 CHECK(armnnDelegate->AllocateTensors() == kTfLiteOk);
148
149 std::unique_ptr<Interpreter> tfLiteDelegate;
150 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
Cathal Corbett839b9322022-11-18 08:52:18 +0000151 (&tfLiteDelegate) == kTfLiteOk);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000152 CHECK(tfLiteDelegate != nullptr);
153 CHECK(tfLiteDelegate->AllocateTensors() == kTfLiteOk);
154
155 // Create the ArmNN Delegate
156 armnnDelegate::DelegateOptions delegateOptions(backends);
157 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
Cathal Corbett839b9322022-11-18 08:52:18 +0000158 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
159 armnnDelegate::TfLiteArmnnDelegateDelete);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000160 CHECK(theArmnnDelegate != nullptr);
161
162 // Modify armnnDelegateInterpreter to use armnnDelegate
163 CHECK(armnnDelegate->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
164
165 // Set input data
166 armnnDelegate::FillInput<T>(tfLiteDelegate, 0, inputValues);
167 armnnDelegate::FillInput<T>(armnnDelegate, 0, inputValues);
168
169 // Run EnqueWorkload
170 CHECK(tfLiteDelegate->Invoke() == kTfLiteOk);
171 CHECK(armnnDelegate->Invoke() == kTfLiteOk);
172
173 // Compare output data
174 armnnDelegate::CompareOutputData<T>(tfLiteDelegate,
175 armnnDelegate,
176 outputTensorShape,
177 expectedOutputValues);
178
179 tfLiteDelegate.reset(nullptr);
180 armnnDelegate.reset(nullptr);
Cathal Corbett839b9322022-11-18 08:52:18 +0000181} // End of Slice Test
Jan Eilers2ffddda2021-02-03 09:14:30 +0000182
183} // anonymous namespace