blob: 4a2537feec5a4698efc16b0d31c88f8a22ffcd51 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
11#include <armnn/DescriptorsFwd.hpp>
12
13#include <flatbuffers/flatbuffers.h>
14#include <tensorflow/lite/interpreter.h>
15#include <tensorflow/lite/kernels/register.h>
16#include <tensorflow/lite/model.h>
17#include <tensorflow/lite/schema/schema_generated.h>
18#include <tensorflow/lite/version.h>
19
20#include <doctest/doctest.h>
21
22#include <string>
23
24namespace
25{
26
Jan Eilers2ffddda2021-02-03 09:14:30 +000027std::vector<char> CreateSliceTfLiteModel(tflite::TensorType tensorType,
28 const std::vector<int32_t>& inputTensorShape,
29 const std::vector<int32_t>& beginTensorData,
Cathal Corbett839b9322022-11-18 08:52:18 +000030 const std::vector<int32_t>& sizeTensorData,
Jan Eilers2ffddda2021-02-03 09:14:30 +000031 const std::vector<int32_t>& beginTensorShape,
Cathal Corbett839b9322022-11-18 08:52:18 +000032 const std::vector<int32_t>& sizeTensorShape,
33 const std::vector<int32_t>& outputTensorShape)
Jan Eilers2ffddda2021-02-03 09:14:30 +000034{
35 using namespace tflite;
36 flatbuffers::FlatBufferBuilder flatBufferBuilder;
37
Cathal Corbett839b9322022-11-18 08:52:18 +000038 std::array<flatbuffers::Offset<tflite::Buffer>, 3> buffers;
Jan Eilers2ffddda2021-02-03 09:14:30 +000039 buffers[0] = CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({}));
40 buffers[1] = CreateBuffer(flatBufferBuilder,
41 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(beginTensorData.data()),
42 sizeof(int32_t) * beginTensorData.size()));
43 buffers[2] = CreateBuffer(flatBufferBuilder,
Cathal Corbett839b9322022-11-18 08:52:18 +000044 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(sizeTensorData.data()),
45 sizeof(int32_t) * sizeTensorData.size()));
Jan Eilers2ffddda2021-02-03 09:14:30 +000046
Cathal Corbett839b9322022-11-18 08:52:18 +000047 std::array<flatbuffers::Offset<Tensor>, 4> tensors;
Jan Eilers2ffddda2021-02-03 09:14:30 +000048 tensors[0] = CreateTensor(flatBufferBuilder,
49 flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
50 inputTensorShape.size()),
51 tensorType,
52 0,
53 flatBufferBuilder.CreateString("input"));
54 tensors[1] = CreateTensor(flatBufferBuilder,
55 flatBufferBuilder.CreateVector<int32_t>(beginTensorShape.data(),
56 beginTensorShape.size()),
57 ::tflite::TensorType_INT32,
58 1,
59 flatBufferBuilder.CreateString("begin_tensor"));
60 tensors[2] = CreateTensor(flatBufferBuilder,
Cathal Corbett839b9322022-11-18 08:52:18 +000061 flatBufferBuilder.CreateVector<int32_t>(sizeTensorShape.data(),
62 sizeTensorShape.size()),
Jan Eilers2ffddda2021-02-03 09:14:30 +000063 ::tflite::TensorType_INT32,
64 2,
Cathal Corbett839b9322022-11-18 08:52:18 +000065 flatBufferBuilder.CreateString("size_tensor"));
Jan Eilers2ffddda2021-02-03 09:14:30 +000066 tensors[3] = CreateTensor(flatBufferBuilder,
Jan Eilers2ffddda2021-02-03 09:14:30 +000067 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
68 outputTensorShape.size()),
69 tensorType,
70 0,
71 flatBufferBuilder.CreateString("output"));
72
73
74 // create operator
Cathal Corbett839b9322022-11-18 08:52:18 +000075 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_SliceOptions;
76 flatbuffers::Offset<void> operatorBuiltinOptions = CreateSliceOptions(flatBufferBuilder).Union();
Jan Eilers2ffddda2021-02-03 09:14:30 +000077
Cathal Corbett839b9322022-11-18 08:52:18 +000078 const std::vector<int> operatorInputs{ 0, 1, 2 };
79 const std::vector<int> operatorOutputs{ 3 };
Jan Eilers2ffddda2021-02-03 09:14:30 +000080 flatbuffers::Offset <Operator> sliceOperator =
Cathal Corbett839b9322022-11-18 08:52:18 +000081 CreateOperator(flatBufferBuilder,
82 0,
83 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
84 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
85 operatorBuiltinOptionsType,
86 operatorBuiltinOptions);
Jan Eilers2ffddda2021-02-03 09:14:30 +000087
Cathal Corbett839b9322022-11-18 08:52:18 +000088 const std::vector<int> subgraphInputs{ 0, 1, 2 };
89 const std::vector<int> subgraphOutputs{ 3 };
Jan Eilers2ffddda2021-02-03 09:14:30 +000090 flatbuffers::Offset <SubGraph> subgraph =
Cathal Corbett839b9322022-11-18 08:52:18 +000091 CreateSubGraph(flatBufferBuilder,
92 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
93 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
94 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
95 flatBufferBuilder.CreateVector(&sliceOperator, 1));
Jan Eilers2ffddda2021-02-03 09:14:30 +000096
97 flatbuffers::Offset <flatbuffers::String> modelDescription =
Cathal Corbett839b9322022-11-18 08:52:18 +000098 flatBufferBuilder.CreateString("ArmnnDelegate: Slice Operator Model");
Jan Eilers2ffddda2021-02-03 09:14:30 +000099 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
Cathal Corbett839b9322022-11-18 08:52:18 +0000100 BuiltinOperator_SLICE);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000101
102 flatbuffers::Offset <Model> flatbufferModel =
Cathal Corbett839b9322022-11-18 08:52:18 +0000103 CreateModel(flatBufferBuilder,
104 TFLITE_SCHEMA_VERSION,
105 flatBufferBuilder.CreateVector(&operatorCode, 1),
106 flatBufferBuilder.CreateVector(&subgraph, 1),
107 modelDescription,
108 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
Jan Eilers2ffddda2021-02-03 09:14:30 +0000109
110 flatBufferBuilder.Finish(flatbufferModel);
111
112 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
113 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
114}
115
116template <typename T>
Cathal Corbett839b9322022-11-18 08:52:18 +0000117void SliceTestImpl(std::vector<armnn::BackendId>& backends,
118 std::vector<T>& inputValues,
119 std::vector<T>& expectedOutputValues,
120 std::vector<int32_t>& beginTensorData,
121 std::vector<int32_t>& sizeTensorData,
122 std::vector<int32_t>& inputTensorShape,
123 std::vector<int32_t>& beginTensorShape,
124 std::vector<int32_t>& sizeTensorShape,
125 std::vector<int32_t>& outputTensorShape)
Jan Eilers2ffddda2021-02-03 09:14:30 +0000126{
127 using namespace tflite;
128 std::vector<char> modelBuffer = CreateSliceTfLiteModel(
Cathal Corbett839b9322022-11-18 08:52:18 +0000129 ::tflite::TensorType_FLOAT32,
130 inputTensorShape,
131 beginTensorData,
132 sizeTensorData,
133 beginTensorShape,
134 sizeTensorShape,
135 outputTensorShape);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000136
137 auto tfLiteModel = GetModel(modelBuffer.data());
138
139 // Create TfLite Interpreters
140 std::unique_ptr<Interpreter> armnnDelegate;
141 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
Cathal Corbett839b9322022-11-18 08:52:18 +0000142 (&armnnDelegate) == kTfLiteOk);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000143 CHECK(armnnDelegate != nullptr);
144 CHECK(armnnDelegate->AllocateTensors() == kTfLiteOk);
145
146 std::unique_ptr<Interpreter> tfLiteDelegate;
147 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
Cathal Corbett839b9322022-11-18 08:52:18 +0000148 (&tfLiteDelegate) == kTfLiteOk);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000149 CHECK(tfLiteDelegate != nullptr);
150 CHECK(tfLiteDelegate->AllocateTensors() == kTfLiteOk);
151
152 // Create the ArmNN Delegate
153 armnnDelegate::DelegateOptions delegateOptions(backends);
154 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
Cathal Corbett839b9322022-11-18 08:52:18 +0000155 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
156 armnnDelegate::TfLiteArmnnDelegateDelete);
Jan Eilers2ffddda2021-02-03 09:14:30 +0000157 CHECK(theArmnnDelegate != nullptr);
158
159 // Modify armnnDelegateInterpreter to use armnnDelegate
160 CHECK(armnnDelegate->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
161
162 // Set input data
163 armnnDelegate::FillInput<T>(tfLiteDelegate, 0, inputValues);
164 armnnDelegate::FillInput<T>(armnnDelegate, 0, inputValues);
165
166 // Run EnqueWorkload
167 CHECK(tfLiteDelegate->Invoke() == kTfLiteOk);
168 CHECK(armnnDelegate->Invoke() == kTfLiteOk);
169
170 // Compare output data
171 armnnDelegate::CompareOutputData<T>(tfLiteDelegate,
172 armnnDelegate,
173 outputTensorShape,
174 expectedOutputValues);
175
176 tfLiteDelegate.reset(nullptr);
177 armnnDelegate.reset(nullptr);
Cathal Corbett839b9322022-11-18 08:52:18 +0000178} // End of Slice Test
Jan Eilers2ffddda2021-02-03 09:14:30 +0000179
180} // anonymous namespace